From e0048bf3a2d4074e0db60d2fc61ef62f44e36f2a Mon Sep 17 00:00:00 2001 From: aphc14 <177544929+aphc14@users.noreply.github.com> Date: Mon, 13 Jan 2025 20:11:04 +1100 Subject: [PATCH 1/3] ENH: Add jitter_scale parameter for initial point generation (#7555) --- pymc/initial_point.py | 8 +++++++- pymc/sampling/mcmc.py | 4 +++- tests/test_initial_point.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/pymc/initial_point.py b/pymc/initial_point.py index 241409f683..a59d7355d1 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -66,6 +66,7 @@ def make_initial_point_fns_per_chain( model, overrides: StartDict | Sequence[StartDict | None] | None, jitter_rvs: set[TensorVariable] | None = None, + jitter_scale: float = 1.0, chains: int, ) -> list[Callable]: """Create an initial point function for each chain, as defined by initvals. @@ -96,6 +97,7 @@ def make_initial_point_fns_per_chain( model=model, overrides=overrides, jitter_rvs=jitter_rvs, + jitter_scale=jitter_scale, return_transformed=True, ) ] * chains @@ -104,6 +106,7 @@ def make_initial_point_fns_per_chain( make_initial_point_fn( model=model, jitter_rvs=jitter_rvs, + jitter_scale=jitter_scale, overrides=chain_overrides, return_transformed=True, ) @@ -122,6 +125,7 @@ def make_initial_point_fn( model, overrides: StartDict | None = None, jitter_rvs: set[TensorVariable] | None = None, + jitter_scale: float = 1.0, default_strategy: str = "support_point", return_transformed: bool = True, ) -> Callable: @@ -150,6 +154,7 @@ def make_initial_point_fn( rvs_to_transforms=model.rvs_to_transforms, initval_strategies=initval_strats, jitter_rvs=jitter_rvs, + jitter_scale=jitter_scale, default_strategy=default_strategy, return_transformed=return_transformed, ) @@ -188,6 +193,7 @@ def make_initial_point_expression( rvs_to_transforms: dict[TensorVariable, Transform], initval_strategies: dict[TensorVariable, np.ndarray | Variable | str | None], jitter_rvs: 
set[TensorVariable] | None = None, + jitter_scale: float = 1.0, default_strategy: str = "support_point", return_transformed: bool = False, ) -> list[TensorVariable]: @@ -265,7 +271,7 @@ def make_initial_point_expression( value = transform.forward(value, *variable.owner.inputs) if variable in jitter_rvs: - jitter = pt.random.uniform(-1, 1, size=value.shape) + jitter = pt.random.uniform(-jitter_scale, jitter_scale, size=value.shape) jitter.name = f"{variable.name}_jitter" value = value + jitter diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index ca91325ff1..397fc3e317 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1423,10 +1423,11 @@ def _init_jitter( initvals: StartDict | Sequence[StartDict | None] | None, seeds: Sequence[int] | np.ndarray, jitter: bool, + jitter_scale: float, jitter_max_retries: int, logp_dlogp_func=None, ) -> list[PointType]: - """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain. + """Apply a uniform jitter in [-jitter_scale, jitter_scale] to the test value as starting point in each chain. ``model.check_start_vals`` is used to test whether the jittered starting values produce a finite log probability. 
Invalid values are resampled @@ -1449,6 +1450,7 @@ def _init_jitter( model=model, overrides=initvals, jitter_rvs=set(model.free_RVs) if jitter else set(), + jitter_scale=jitter_scale if jitter else 1.0, chains=len(seeds), ) diff --git a/tests/test_initial_point.py b/tests/test_initial_point.py index 9138f37b3e..8f6bc56d29 100644 --- a/tests/test_initial_point.py +++ b/tests/test_initial_point.py @@ -152,6 +152,34 @@ def test_adds_jitter(self): assert fn(0) == fn(0) assert fn(0) != fn(1) + def test_jitter_scale(self): + with pm.Model() as pmodel: + A = pm.HalfFlat("A", initval="support_point") + + jitter_scale_tests = np.array([1.0, 2.0, 5.0]) + fns = [] + for jitter_scale in jitter_scale_tests: + fns.append( + make_initial_point_fn( + model=pmodel, + jitter_rvs=set(pmodel.free_RVs), + jitter_scale=jitter_scale, + return_transformed=True, + ) + ) + + n_draws = 1000 + jitter_samples = np.empty((n_draws, len(fns))) + for j, fn in enumerate(fns): + # start and end to ensure random samples, otherwise jitter_samples across different jitter_scale will be an exact scale of each other + start = j * n_draws + end = start + n_draws + jitter_samples[:, j] = np.asarray([fn(i)["A_log__"] for i in range(start, end)]) + + init_standardised = np.mean((jitter_samples / jitter_scale_tests), axis=0) + + assert np.all((-0.05 < init_standardised) & (init_standardised < 0.05)) + def test_respects_overrides(self): with pm.Model() as pmodel: A = pm.Flat("A", initval="support_point") From b6517af6d9b6d4785009be8108df779cb8d200d7 Mon Sep 17 00:00:00 2001 From: aphc14 <177544929+aphc14@users.noreply.github.com> Date: Wed, 15 Jan 2025 21:10:19 +1100 Subject: [PATCH 2/3] Added docstrings and simplified tests --- pymc/initial_point.py | 14 +++++++++----- pymc/sampling/mcmc.py | 2 ++ tests/test_initial_point.py | 33 ++++++++++++--------------------- 3 files changed, 23 insertions(+), 26 deletions(-) diff --git a/pymc/initial_point.py b/pymc/initial_point.py index a59d7355d1..f7c7bd5114 100644 --- 
a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -71,8 +71,7 @@ def make_initial_point_fns_per_chain( ) -> list[Callable]: """Create an initial point function for each chain, as defined by initvals. - If a single initval dictionary is passed, the function is replicated for each - chain, otherwise a unique function is compiled for each entry in the dictionary. + If a single initval dictionary is passed, the function is replicated for each chain, otherwise a unique function is compiled for each entry in the dictionary. Parameters ---------- @@ -82,6 +81,8 @@ def make_initial_point_fns_per_chain( jitter_rvs : set, optional Random variable tensors for which U(-1, 1) jitter shall be applied. (To the transformed space if applicable.) + jitter_scale : float, optional + The scale of the jitter in the jitter_rvs set. Defaults to 1.0. Raises ------ @@ -134,8 +135,9 @@ def make_initial_point_fn( Parameters ---------- jitter_rvs : set - The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be - added to the initial value. Only available for variables that have a transform or real-valued support. + The set (or list or tuple) of random variables for which a U(-jitter_scale, +jitter_scale) jitter should be added to the initial value. Only available for variables that have a transform or real-valued support. + jitter_scale : float, optional + Half-width of the uniform jitter interval applied to the variables in jitter_rvs. Defaults to 1.0. default_strategy : str Which of { "support_point", "prior" } to prefer if the initval setting for an RV is None. overrides : dict @@ -209,8 +211,10 @@ def make_initial_point_expression( Mapping of free random variable tensors to initial value strategies. For example the `Model.initial_values` dictionary. jitter_rvs : set - The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be + The set (or list or tuple) of random variables for which a U(-1, 1) jitter should be added to the initial value. 
Only available for variables that have a transform or real-valued support. + jitter_scale : float, optional + The scale of the jitter in the jitter_rvs set. Defaults to 1.0. default_strategy : str Which of { "support_point", "prior" } to prefer if the initval strategy setting for an RV is None. return_transformed : bool diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 397fc3e317..0c70639282 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1438,6 +1438,8 @@ def _init_jitter( ---------- jitter: bool Whether to apply jitter or not. + jitter_scale : float + Half-width of the uniform jitter, U(-jitter_scale, jitter_scale), applied to set(model.free_RVs). Has no effect when jitter is False. jitter_max_retries : int Maximum number of repeated attempts at initializing values (per chain). diff --git a/tests/test_initial_point.py b/tests/test_initial_point.py index 8f6bc56d29..f94a1e7809 100644 --- a/tests/test_initial_point.py +++ b/tests/test_initial_point.py @@ -156,29 +156,20 @@ def test_jitter_scale(self): with pm.Model() as pmodel: A = pm.HalfFlat("A", initval="support_point") - jitter_scale_tests = np.array([1.0, 2.0, 5.0]) - fns = [] - for jitter_scale in jitter_scale_tests: - fns.append( - make_initial_point_fn( - model=pmodel, - jitter_rvs=set(pmodel.free_RVs), - jitter_scale=jitter_scale, - return_transformed=True, - ) - ) - - n_draws = 1000 - jitter_samples = np.empty((n_draws, len(fns))) - for j, fn in enumerate(fns): - # start and end to ensure random samples, otherwise jitter_samples across different jitter_scale will be an exact scale of each other - start = j * n_draws - end = start + n_draws - jitter_samples[:, j] = np.asarray([fn(i)["A_log__"] for i in range(start, end)]) + fn_default = make_initial_point_fn( + model=pmodel, + jitter_rvs=set(pmodel.free_RVs), + return_transformed=True, + ) - init_standardised = np.mean((jitter_samples / jitter_scale_tests), axis=0) + fn_large = make_initial_point_fn( + model=pmodel, + jitter_rvs=set(pmodel.free_RVs), + jitter_scale=1000.0, + 
return_transformed=True, + ) - assert np.all((-0.05 < init_standardised) & (init_standardised < 0.05)) + assert fn_large(0)["A_log__"] > fn_default(0)["A_log__"] def test_respects_overrides(self): with pm.Model() as pmodel: From c747e63026367e0f8a884c776f2589f260e12f62 Mon Sep 17 00:00:00 2001 From: Michael Cao <177544929+aphc14@users.noreply.github.com> Date: Sun, 19 Jan 2025 18:56:21 +1100 Subject: [PATCH 3/3] Applied code suggestions: jitter_large > 10, jitter_default < 1 Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- tests/test_initial_point.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_initial_point.py b/tests/test_initial_point.py index f94a1e7809..ddc087c14e 100644 --- a/tests/test_initial_point.py +++ b/tests/test_initial_point.py @@ -169,7 +169,8 @@ def test_jitter_scale(self): return_transformed=True, ) - assert fn_large(0)["A_log__"] > fn_default(0)["A_log__"] + assert fn_large(0)["A_log__"] > 10 + assert fn_default(0)["A_log__"] < 1 def test_respects_overrides(self): with pm.Model() as pmodel: