Skip to content

Commit 65feb50

Browse files
theogfgithub-actions[bot]devmotion
authored
Add init_params keyword argument (#26)
* Update abstractmcmc.jl * Update interface.jl * Update src/interface.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Modified proposal and added basic tests * Update src/abstractmcmc.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Correct typos * Mentionned keyword argument in docs, patch bump and test fix * Update test/simple.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix tests * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann <[email protected]>
1 parent ca39ed3 commit 65feb50

File tree

4 files changed

+33
-3
lines changed

4 files changed

+33
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "EllipticalSliceSampling"
22
uuid = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
33
authors = ["David Widmann <[email protected]>"]
4-
version = "0.4.5"
4+
version = "0.4.6"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

docs/src/index.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ AbstractMCMC.steps(
5555
gives you access to an iterator from which you can generate an unlimited
5656
number of samples.
5757

58+
You can define the starting point of your chain using the `init_params` keyword argument.
59+
5860
For more details regarding `sample` and `steps` please check the documentation of
5961
[AbstractMCMC.jl](https://github.com/TuringLang/AbstractMCMC.jl).
6062

src/abstractmcmc.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,14 @@ end
1919

2020
# first step of the elliptical slice sampler
2121
function AbstractMCMC.step(
22-
rng::Random.AbstractRNG, model::AbstractMCMC.AbstractModel, ::ESS; kwargs...
22+
rng::Random.AbstractRNG,
23+
model::AbstractMCMC.AbstractModel,
24+
::ESS;
25+
init_params=nothing,
26+
kwargs...,
2327
)
2428
# initial sample from the Gaussian prior
25-
f = initial_sample(rng, model)
29+
f = init_params === nothing ? initial_sample(rng, model) : init_params
2630

2731
# compute log-likelihood of the initial sample
2832
loglikelihood = Distributions.loglikelihood(model, f)

test/simple.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@
3434
@test mean(mean, samples) μ atol = 0.05
3535
@test mean(var, samples) σ² atol = 0.05
3636
end
37+
38+
# initial parameter
39+
init_x = randn()
40+
samples = sample(ESSModel(prior, ℓ), ESS(), 10; progress=false, init_params=init_x)
41+
@test first(samples) == init_x
3742
end
3843

3944
@testset "Scalar model with nonzero mean" begin
@@ -62,6 +67,11 @@
6267
@test mean(mean, samples) μ atol = 0.05
6368
@test mean(var, samples) σ² atol = 0.05
6469
end
70+
71+
# initial parameter
72+
init_x = randn()
73+
samples = sample(ESSModel(prior, ℓ), ESS(), 10; progress=false, init_params=init_x)
74+
@test first(samples) == init_x
6575
end
6676

6777
@testset "Scalar model (vectorized)" begin
@@ -91,6 +101,13 @@
91101
@test mean(mean, samples) μ atol = 0.05
92102
@test mean(var, samples) σ² atol = 0.05
93103
end
104+
105+
# initial parameter
106+
init_x = randn(1)
107+
samples = sample(
108+
ESSModel(prior, ℓvec), ESS(), 10; progress=false, init_params=init_x
109+
)
110+
@test first(samples) == init_x
94111
end
95112

96113
@testset "Scalar model with nonzero mean (vectorized)" begin
@@ -120,5 +137,12 @@
120137
@test mean(mean, samples) μ atol = 0.05
121138
@test mean(var, samples) σ² atol = 0.05
122139
end
140+
141+
# initial parameter
142+
init_x = randn(1)
143+
samples = sample(
144+
ESSModel(prior, ℓvec), ESS(), 10; progress=false, init_params=init_x
145+
)
146+
@test first(samples) == init_x
123147
end
124148
end

0 commit comments

Comments
 (0)