-
Notifications
You must be signed in to change notification settings - Fork 0
Composite Sampler #250
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Composite Sampler #250
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from __future__ import annotations | ||
|
||
import torch | ||
|
||
from confopt.oneshot.archsampler import BaseSampler | ||
from confopt.oneshot.base import OneShotComponent | ||
|
||
|
||
class CompositeSampler(OneShotComponent): | ||
def __init__( | ||
self, | ||
arch_samplers: list[BaseSampler], | ||
arch_parameters: list[torch.Tensor], | ||
) -> None: | ||
super().__init__() | ||
self.arch_samplers = arch_samplers | ||
self.arch_parameters = arch_parameters | ||
|
||
# get sample frequency from the samplers | ||
self.sample_frequency = arch_samplers[0].sample_frequency | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we still use sampling frequency in the code? |
||
for sampler in arch_samplers: | ||
assert ( | ||
self.sample_frequency == sampler.sample_frequency | ||
), "All the sampler must have the same sample frequency" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "All samplers must have" |
||
|
||
def sample(self, alpha: torch.Tensor) -> torch.Tensor: | ||
sampled_alphas = alpha | ||
for sampler in self.arch_samplers: | ||
sampled_alphas = sampler.sample(sampled_alphas) | ||
|
||
return sampled_alphas | ||
|
||
def new_epoch(self) -> None: | ||
super().new_epoch() | ||
for sampler in self.arch_samplers: | ||
sampler.new_epoch() | ||
|
||
def new_step(self) -> None: | ||
super().new_step() | ||
for sampler in self.arch_samplers: | ||
sampler.new_step() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,8 @@ | |
from abc import ABC | ||
from typing import Any | ||
|
||
from typing_extensions import override | ||
|
||
from confopt.enums import SamplerType, SearchSpaceType | ||
from confopt.searchspace.darts.core.genotypes import DARTSGenotype | ||
from confopt.utils import get_num_classes | ||
|
@@ -402,6 +404,147 @@ def configure_sampler(self, **kwargs) -> None: # type: ignore | |
super().configure_sampler(**kwargs) | ||
|
||
|
||
class CompositeProfile(BaseProfile, ABC): | ||
SAMPLER_TYPE = SamplerType.COMPOSITE | ||
|
||
def __init__( | ||
self, | ||
searchspace_type: str | SearchSpaceType, | ||
samplers: list[str | SamplerType], | ||
epochs: int, | ||
# GDAS configs | ||
tau_min: float = 0.1, | ||
tau_max: float = 10, | ||
# SNAS configs | ||
temp_init: float = 1.0, | ||
temp_min: float = 0.03, | ||
temp_annealing: bool = True, | ||
**kwargs: Any, | ||
) -> None: | ||
self.samplers = [] | ||
for sampler in samplers: | ||
if isinstance(sampler, str): | ||
sampler = SamplerType(sampler) | ||
self.samplers.append(sampler) | ||
|
||
self.tau_min = tau_min | ||
self.tau_max = tau_max | ||
|
||
self.temp_init = temp_init | ||
self.temp_min = temp_min | ||
self.temp_annealing = temp_annealing | ||
|
||
super().__init__( # type: ignore | ||
self.SAMPLER_TYPE, | ||
searchspace_type, | ||
epochs, | ||
**kwargs, | ||
) | ||
|
||
def _initialize_sampler_config(self) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
EDIT: I see it's called in the base class |
||
"""Initializes the sampler configuration for Composite samplers. | ||
|
||
The sampler configuration includes the sample frequency and the architecture | ||
combine function. | ||
|
||
Args: | ||
None | ||
|
||
Returns: | ||
None | ||
""" | ||
self.sampler_config: dict[int, dict] = {} # type: ignore | ||
for i, sampler in enumerate(self.samplers): | ||
if sampler in [SamplerType.DARTS, SamplerType.DRNAS]: | ||
config = { | ||
"sampler_type": sampler, | ||
"sample_frequency": self.sampler_sample_frequency, | ||
"arch_combine_fn": self.sampler_arch_combine_fn, | ||
} | ||
elif sampler in [SamplerType.GDAS, SamplerType.REINMAX]: | ||
config = { | ||
"sampler_type": sampler, | ||
"sample_frequency": self.sampler_sample_frequency, | ||
"arch_combine_fn": self.sampler_arch_combine_fn, | ||
"tau_min": self.tau_min, | ||
"tau_max": self.tau_max, | ||
} | ||
elif sampler == SamplerType.SNAS: | ||
config = { | ||
"sampler_type": sampler, | ||
"sample_frequency": self.sampler_sample_frequency, | ||
"arch_combine_fn": self.sampler_arch_combine_fn, | ||
"temp_init": self.temp_init, | ||
"temp_min": self.temp_min, | ||
"temp_annealing": self.temp_annealing, | ||
"total_epochs": self.epochs, | ||
} | ||
else: | ||
raise AttributeError(f"Illegal sampler type {sampler} provided!") | ||
|
||
self.sampler_config[i] = config | ||
|
||
@override | ||
def configure_sampler( # type: ignore[override] | ||
self, sampler_config_map: dict[int, dict] | ||
) -> None: | ||
"""Configures the sampler settings based on the provided configurations. | ||
|
||
Args: | ||
sampler_config_map (dict[int, dict]): A dictionary where each key is an \ | ||
integer representing the order of the sampler (zero-indexed), and \ | ||
each value is a dictionary containing the configuration parameters \ | ||
for that sampler. | ||
|
||
The inner configuration dictionary can contain different sets of keys | ||
depending on the type of sampler being configured. The available keys | ||
include: | ||
|
||
Generic Configurations: | ||
- sample_frequency (str): The rate at which samples should be taken. | ||
- arch_combine_fn (str): Function to combine architectures. | ||
For FairDARTS, set this to 'sigmoid'. Default is 'default'. | ||
|
||
GDAS-specific Configurations: | ||
- tau_min (float): Minimum temperature for sampling. | ||
- tau_max (float): Maximum temperature for sampling. | ||
|
||
SNAS-specific Configurations: | ||
- temp_init (float): Initial temperature for sampling. | ||
- temp_min (float): Minimum temperature for sampling. | ||
- temp_annealing (bool): Whether to apply temperature annealing. | ||
- total_epochs (int): Total number of training epochs. | ||
|
||
The specific keys required in the dictionary depend on the type of | ||
sampler being used. | ||
Please make sure that the sample frequency of all the configurations are same. | ||
Each configuration is validated, and an error is raised if unknown keys are | ||
provided. | ||
|
||
Raises: | ||
ValueError: If an unrecognized configuration key is detected. | ||
|
||
Returns: | ||
None | ||
""" | ||
assert self.sampler_config is not None | ||
|
||
for idx in sampler_config_map: | ||
for config_key in sampler_config_map[idx]: | ||
assert idx in self.sampler_config | ||
exists = False | ||
sampler_type = self.sampler_config[idx]["sampler_type"] | ||
if config_key in self.sampler_config[idx]: | ||
exists = True | ||
self.sampler_config[idx][config_key] = sampler_config_map[idx][ | ||
config_key | ||
] | ||
assert exists, ( | ||
f"{config_key} is not a valid configuration for {sampler_type}", | ||
"sampler inside composite sampler", | ||
) | ||
|
||
|
||
class DiscreteProfile: | ||
def __init__( | ||
self, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,6 +40,7 @@ | |
) | ||
from confopt.oneshot.archsampler import ( | ||
BaseSampler, | ||
CompositeSampler, | ||
DARTSSampler, | ||
DRNASSampler, | ||
GDASSampler, | ||
|
@@ -374,17 +375,36 @@ def _set_sampler( | |
config: dict, | ||
) -> None: | ||
arch_params = self.search_space.arch_parameters | ||
self.sampler: BaseSampler | None = None | ||
if sampler == SamplerType.DARTS: | ||
self.sampler = DARTSSampler(**config, arch_parameters=arch_params) | ||
elif sampler == SamplerType.DRNAS: | ||
self.sampler = DRNASSampler(**config, arch_parameters=arch_params) | ||
elif sampler == SamplerType.GDAS: | ||
self.sampler = GDASSampler(**config, arch_parameters=arch_params) | ||
elif sampler == SamplerType.SNAS: | ||
self.sampler = SNASSampler(**config, arch_parameters=arch_params) | ||
elif sampler == SamplerType.REINMAX: | ||
self.sampler = ReinMaxSampler(**config, arch_parameters=arch_params) | ||
self.sampler: BaseSampler | CompositeSampler | None = None | ||
|
||
def _get_sampler_class(sampler: SamplerType) -> Callable: | ||
if sampler == SamplerType.DARTS: | ||
return DARTSSampler | ||
if sampler == SamplerType.DRNAS: | ||
return DRNASSampler | ||
if sampler == SamplerType.GDAS: | ||
return GDASSampler | ||
if sampler == SamplerType.SNAS: | ||
return SNASSampler | ||
if sampler == SamplerType.REINMAX: | ||
return ReinMaxSampler | ||
|
||
raise ValueError(f"Illegal sampler {sampler} provided") | ||
|
||
if sampler == SamplerType.COMPOSITE: | ||
sub_samplers: list[BaseSampler] = [] | ||
for _, sampler_config in config.items(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I'm not mistaken, |
||
sampler_type = sampler_config["sampler_type"] | ||
del sampler_config["sampler_type"] | ||
sampler_component = _get_sampler_class(sampler_type)( | ||
**sampler_config, arch_parameters=arch_params | ||
) | ||
sub_samplers.append(sampler_component) | ||
self.sampler = CompositeSampler(sub_samplers, arch_parameters=arch_params) | ||
else: | ||
self.sampler = _get_sampler_class(sampler)( | ||
**config, arch_parameters=arch_params | ||
) | ||
|
||
def _set_perturbator( | ||
self, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why inherit from
OneShotComponent
? Why not inherit fromBaseSampler
?