Skip to content

Composite Sampler #250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/confopt/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class SamplerType(Enum):
GDAS = "gdas"
SNAS = "snas"
REINMAX = "reinmax"
COMPOSITE = "composite"

def __str__(self) -> str:
return self.value
Expand Down
2 changes: 2 additions & 0 deletions src/confopt/oneshot/archsampler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .base_sampler import BaseSampler
from .composite_sampler import CompositeSampler
from .darts.sampler import DARTSSampler
from .drnas.sampler import DRNASSampler
from .gdas.sampler import GDASSampler
Expand All @@ -12,4 +13,5 @@
"GDASSampler",
"SNASSampler",
"ReinMaxSampler",
"CompositeSampler",
]
41 changes: 41 additions & 0 deletions src/confopt/oneshot/archsampler/composite_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

import torch

from confopt.oneshot.archsampler import BaseSampler
from confopt.oneshot.base import OneShotComponent


class CompositeSampler(OneShotComponent):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why inherit from OneShotComponent? Why not inherit from BaseSampler?

def __init__(
self,
arch_samplers: list[BaseSampler],
arch_parameters: list[torch.Tensor],
) -> None:
super().__init__()
self.arch_samplers = arch_samplers
self.arch_parameters = arch_parameters

# get sample frequency from the samplers
self.sample_frequency = arch_samplers[0].sample_frequency
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still use sampling frequency in the code?

for sampler in arch_samplers:
assert (
self.sample_frequency == sampler.sample_frequency
), "All the sampler must have the same sample frequency"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"All samplers must have"


def sample(self, alpha: torch.Tensor) -> torch.Tensor:
    """Run ``alpha`` through every wrapped sampler, one after another.

    The samplers in ``self.arch_samplers`` are applied in list order; each
    one receives the output of the previous sampler's ``sample`` call.

    Args:
        alpha: Value handed to the first sampler.

    Returns:
        The output of the last sampler, or ``alpha`` unchanged when
        ``self.arch_samplers`` is empty.
    """
    current = alpha
    for inner_sampler in self.arch_samplers:
        current = inner_sampler.sample(current)
    return current

def new_epoch(self) -> None:
    """Forward the ``new_epoch`` call to the parent class and each sampler.

    The parent implementation runs first; afterwards ``new_epoch`` is invoked
    on every entry of ``self.arch_samplers`` in list order.
    """
    super().new_epoch()
    for inner_sampler in self.arch_samplers:
        inner_sampler.new_epoch()

def new_step(self) -> None:
    """Forward the ``new_step`` call to the parent class and each sampler.

    The parent implementation runs first; afterwards ``new_step`` is invoked
    on every entry of ``self.arch_samplers`` in list order.
    """
    super().new_step()
    for inner_sampler in self.arch_samplers:
        inner_sampler.new_step()
2 changes: 2 additions & 0 deletions src/confopt/profile/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .base import BaseProfile
from .profiles import (
CompositeProfile,
DARTSProfile,
DiscreteProfile,
DRNASProfile,
Expand All @@ -16,4 +17,5 @@
"SNASProfile",
"DiscreteProfile",
"ReinMaxProfile",
"CompositeProfile",
]
143 changes: 143 additions & 0 deletions src/confopt/profile/profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from abc import ABC
from typing import Any

from typing_extensions import override

from confopt.enums import SamplerType, SearchSpaceType
from confopt.searchspace.darts.core.genotypes import DARTSGenotype
from confopt.utils import get_num_classes
Expand Down Expand Up @@ -402,6 +404,147 @@ def configure_sampler(self, **kwargs) -> None: # type: ignore
super().configure_sampler(**kwargs)


class CompositeProfile(BaseProfile, ABC):
SAMPLER_TYPE = SamplerType.COMPOSITE

def __init__(
    self,
    searchspace_type: str | SearchSpaceType,
    samplers: list[str | SamplerType],
    epochs: int,
    # GDAS configs
    tau_min: float = 0.1,
    tau_max: float = 10,
    # SNAS configs
    temp_init: float = 1.0,
    temp_min: float = 0.03,
    temp_annealing: bool = True,
    **kwargs: Any,
) -> None:
    """Create a profile that chains several samplers together.

    Args:
        searchspace_type: Search space identifier, forwarded to the base
            profile.
        samplers: Sub-samplers in the order given; string entries are
            converted to ``SamplerType`` members.
        epochs: Number of epochs, forwarded to the base profile.
        tau_min: Stored as ``self.tau_min`` (GDAS setting).
        tau_max: Stored as ``self.tau_max`` (GDAS setting).
        temp_init: Stored as ``self.temp_init`` (SNAS setting).
        temp_min: Stored as ``self.temp_min`` (SNAS setting).
        temp_annealing: Stored as ``self.temp_annealing`` (SNAS setting).
        **kwargs: Passed through unchanged to the base profile.
    """
    # NOTE(review): everything is stored before the base initializer runs,
    # presumably because the base class builds the sampler config from these
    # attributes - confirm against BaseProfile.
    self.samplers = [
        SamplerType(entry) if isinstance(entry, str) else entry
        for entry in samplers
    ]

    # GDAS temperature bounds.
    self.tau_min, self.tau_max = tau_min, tau_max

    # SNAS temperature schedule.
    self.temp_init, self.temp_min = temp_init, temp_min
    self.temp_annealing = temp_annealing

    super().__init__(  # type: ignore
        self.SAMPLER_TYPE,
        searchspace_type,
        epochs,
        **kwargs,
    )

def _initialize_sampler_config(self) -> None:
Copy link
Contributor

@Neonkraft Neonkraft May 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is defined but never called.

EDIT: I see it's called in the base class

"""Initializes the sampler configuration for Composite samplers.

The sampler configuration includes the sample frequency and the architecture
combine function.

Args:
None

Returns:
None
"""
self.sampler_config: dict[int, dict] = {} # type: ignore
for i, sampler in enumerate(self.samplers):
if sampler in [SamplerType.DARTS, SamplerType.DRNAS]:
config = {
"sampler_type": sampler,
"sample_frequency": self.sampler_sample_frequency,
"arch_combine_fn": self.sampler_arch_combine_fn,
}
elif sampler in [SamplerType.GDAS, SamplerType.REINMAX]:
config = {
"sampler_type": sampler,
"sample_frequency": self.sampler_sample_frequency,
"arch_combine_fn": self.sampler_arch_combine_fn,
"tau_min": self.tau_min,
"tau_max": self.tau_max,
}
elif sampler == SamplerType.SNAS:
config = {
"sampler_type": sampler,
"sample_frequency": self.sampler_sample_frequency,
"arch_combine_fn": self.sampler_arch_combine_fn,
"temp_init": self.temp_init,
"temp_min": self.temp_min,
"temp_annealing": self.temp_annealing,
"total_epochs": self.epochs,
}
else:
raise AttributeError(f"Illegal sampler type {sampler} provided!")

self.sampler_config[i] = config

@override
def configure_sampler(  # type: ignore[override]
    self, sampler_config_map: dict[int, dict]
) -> None:
    """Override configuration values of individual sub-samplers.

    Args:
        sampler_config_map (dict[int, dict]): Maps the zero-indexed position
            of a sub-sampler to a dictionary of configuration values to
            overwrite for that sub-sampler.

    Only keys that already exist in the stored configuration of the addressed
    sub-sampler may be overridden. Which keys exist depends on the sampler
    type:

    Generic Configurations:
        - sample_frequency (str): The rate at which samples should be taken.
        - arch_combine_fn (str): Function to combine architectures.
          For FairDARTS, set this to 'sigmoid'. Default is 'default'.

    GDAS-specific Configurations:
        - tau_min (float): Minimum temperature for sampling.
        - tau_max (float): Maximum temperature for sampling.

    SNAS-specific Configurations:
        - temp_init (float): Initial temperature for sampling.
        - temp_min (float): Minimum temperature for sampling.
        - temp_annealing (bool): Whether to apply temperature annealing.
        - total_epochs (int): Total number of training epochs.

    Please make sure that the sample frequency is the same in all
    configurations. Values are applied in iteration order, so overrides that
    precede an invalid entry remain applied when the error is raised.

    Raises:
        AssertionError: If a sampler index or a configuration key is not
            present in the stored sampler configuration.

    Returns:
        None
    """
    assert self.sampler_config is not None

    # The checks below raise AssertionError explicitly instead of using
    # ``assert``: the exception type stays the same for existing callers, but
    # the validation is no longer stripped when Python runs with ``-O``.
    for idx, overrides in sampler_config_map.items():
        # Validate the index once per sampler, even when ``overrides`` is
        # empty, so an out-of-range position is never silently accepted.
        if idx not in self.sampler_config:
            raise AssertionError(
                f"{idx} is not a valid sampler index inside composite sampler"
            )

        stored_config = self.sampler_config[idx]
        sampler_type = stored_config["sampler_type"]
        for config_key, value in overrides.items():
            if config_key not in stored_config:
                # A single string message (the previous code accidentally
                # built a tuple because of a stray comma).
                raise AssertionError(
                    f"{config_key} is not a valid configuration for "
                    f"{sampler_type} sampler inside composite sampler"
                )
            stored_config[config_key] = value


class DiscreteProfile:
def __init__(
self,
Expand Down
42 changes: 31 additions & 11 deletions src/confopt/train/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
)
from confopt.oneshot.archsampler import (
BaseSampler,
CompositeSampler,
DARTSSampler,
DRNASSampler,
GDASSampler,
Expand Down Expand Up @@ -374,17 +375,36 @@ def _set_sampler(
config: dict,
) -> None:
arch_params = self.search_space.arch_parameters
self.sampler: BaseSampler | None = None
if sampler == SamplerType.DARTS:
self.sampler = DARTSSampler(**config, arch_parameters=arch_params)
elif sampler == SamplerType.DRNAS:
self.sampler = DRNASSampler(**config, arch_parameters=arch_params)
elif sampler == SamplerType.GDAS:
self.sampler = GDASSampler(**config, arch_parameters=arch_params)
elif sampler == SamplerType.SNAS:
self.sampler = SNASSampler(**config, arch_parameters=arch_params)
elif sampler == SamplerType.REINMAX:
self.sampler = ReinMaxSampler(**config, arch_parameters=arch_params)
self.sampler: BaseSampler | CompositeSampler | None = None

def _get_sampler_class(sampler: SamplerType) -> Callable:
    """Return the sampler class that corresponds to ``sampler``.

    Args:
        sampler: Sampler type to look up.

    Returns:
        The matching sampler class (not an instance).

    Raises:
        ValueError: If ``sampler`` matches none of the entries below.
    """
    known_samplers = (
        (SamplerType.DARTS, DARTSSampler),
        (SamplerType.DRNAS, DRNASSampler),
        (SamplerType.GDAS, GDASSampler),
        (SamplerType.SNAS, SNASSampler),
        (SamplerType.REINMAX, ReinMaxSampler),
    )
    for sampler_type, sampler_cls in known_samplers:
        if sampler == sampler_type:
            return sampler_cls

    raise ValueError(f"Illegal sampler {sampler} provided")

if sampler == SamplerType.COMPOSITE:
sub_samplers: list[BaseSampler] = []
for _, sampler_config in config.items():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm not mistaken, config.items() might not return the values in a specific order. This could mean that multiple instantiations of the CompositeSampler with the same samplers end up with the samplers in a different order.

sampler_type = sampler_config["sampler_type"]
del sampler_config["sampler_type"]
sampler_component = _get_sampler_class(sampler_type)(
**sampler_config, arch_parameters=arch_params
)
sub_samplers.append(sampler_component)
self.sampler = CompositeSampler(sub_samplers, arch_parameters=arch_params)
else:
self.sampler = _get_sampler_class(sampler)(
**config, arch_parameters=arch_params
)

def _set_perturbator(
self,
Expand Down
9 changes: 7 additions & 2 deletions src/confopt/train/search_space_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

import torch

from confopt.oneshot.archsampler import BaseSampler, DARTSSampler, GDASSampler
from confopt.oneshot.archsampler import (
BaseSampler,
CompositeSampler,
DARTSSampler,
GDASSampler,
)
from confopt.oneshot.dropout import Dropout
from confopt.oneshot.lora_toggler import LoRAToggler
from confopt.oneshot.partial_connector import PartialConnector
Expand All @@ -23,7 +28,7 @@
class SearchSpaceHandler:
def __init__(
self,
sampler: BaseSampler,
sampler: BaseSampler | CompositeSampler,
edge_normalization: bool = False,
partial_connector: PartialConnector | None = None,
perturbation: BasePerturbator | None = None,
Expand Down
55 changes: 55 additions & 0 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
DRNASSampler,
GDASSampler,
ReinMaxSampler,
CompositeSampler,
)


Expand Down Expand Up @@ -111,5 +112,59 @@ def test_post_sample_fn_sigmoid(self) -> None:
_test_arch_combine_fn_default(sampler, alphas)


class TestCompositeSampler(unittest.TestCase):
    """Runs the shared ``arch_combine_fn`` checks against ``CompositeSampler``."""

    @staticmethod
    def _random_alphas() -> list:
        """Return two freshly drawn 14x8 random tensors."""
        return [torch.randn(14, 8), torch.randn(14, 8)]

    @staticmethod
    def _compose(alphas: list, layout: list) -> "CompositeSampler":
        """Build a ``CompositeSampler`` from ``(sampler_class, combine_fn)`` pairs.

        Every inner sampler is created on ``alphas`` with
        ``sample_frequency="epoch"``, in the order given by ``layout``.
        """
        inner_samplers = [
            sampler_cls(
                alphas, sample_frequency="epoch", arch_combine_fn=combine_fn
            )
            for sampler_cls, combine_fn in layout
        ]
        return CompositeSampler(inner_samplers, alphas)

    def test_post_sample_fn_default(self) -> None:
        alphas = self._random_alphas()
        sampler = self._compose(
            alphas,
            [
                (DRNASSampler, "default"),
                (GDASSampler, "default"),
                (DARTSSampler, "default"),
            ],
        )
        _test_arch_combine_fn_default(sampler, alphas)

    def test_post_sample_fn_sigmoid(self) -> None:
        alphas = self._random_alphas()
        sampler = self._compose(
            alphas,
            [
                (DRNASSampler, "sigmoid"),
                (GDASSampler, "sigmoid"),
                (DARTSSampler, "sigmoid"),
            ],
        )
        _test_arch_combine_fn_sigmoid(sampler, alphas)

    def test_post_sample_fn_mixed(self) -> None:
        # The same alphas are reused for all three cases.
        alphas = self._random_alphas()

        # Case 1
        # The last sampler's combine function matters,
        # and if it's GDAS/ReinMax, it would get ignored.
        sampler = self._compose(
            alphas,
            [
                (DRNASSampler, "default"),
                (GDASSampler, "sigmoid"),
                (DARTSSampler, "default"),
            ],
        )
        _test_arch_combine_fn_default(sampler, alphas)

        # Case 2
        sampler = self._compose(
            alphas,
            [
                (DRNASSampler, "default"),
                (DARTSSampler, "sigmoid"),
                (GDASSampler, "sigmoid"),
            ],
        )
        _test_arch_combine_fn_default(sampler, alphas)

        # Case 3
        sampler = self._compose(
            alphas,
            [
                (DRNASSampler, "default"),
                (GDASSampler, "sigmoid"),
                (DARTSSampler, "sigmoid"),
            ],
        )
        _test_arch_combine_fn_sigmoid(sampler, alphas)


if __name__ == "__main__":
unittest.main()