Skip to content

Commit 003ff2f

Browse files
sunxd3github-actions[bot]penelopeysm
authored
Improve error message for initial_params (#772)
* improve error message for `initial_params` * Update src/sampler.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix logic error * apply Will's suggestions * bump version * Update src/sampler.jl Co-authored-by: Penelope Yong <[email protected]> * Update Project.toml --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Penelope Yong <[email protected]>
1 parent 3d18cfc commit 003ff2f

File tree

2 files changed

+56
-1
lines changed

2 files changed

+56
-1
lines changed

src/sampler.jl

+31-1
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,17 @@ By default, it returns an instance of [`SampleFromPrior`](@ref).
156156
"""
157157
initialsampler(spl::Sampler) = SampleFromPrior()
158158

159+
function set_values!!(
160+
varinfo::AbstractVarInfo, initial_params::AbstractVector, spl::AbstractSampler
161+
)
162+
throw(
163+
ArgumentError(
164+
"`initial_params` must be a vector of type `Union{Real,Missing}`. " *
165+
"If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first.",
166+
),
167+
)
168+
end
169+
159170
function set_values!!(
160171
varinfo::AbstractVarInfo,
161172
initial_params::AbstractVector{<:Union{Real,Missing}},
@@ -164,7 +175,8 @@ function set_values!!(
164175
flattened_param_vals = varinfo[spl]
165176
length(flattened_param_vals) == length(initial_params) || throw(
166177
DimensionMismatch(
167-
"Provided initial value size ($(length(initial_params))) doesn't match the model size ($(length(flattened_param_vals)))",
178+
"Provided initial value size ($(length(initial_params))) doesn't match " *
179+
"the model size ($(length(flattened_param_vals))).",
168180
),
169181
)
170182

@@ -183,6 +195,24 @@ end
183195
function set_values!!(
184196
varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler
185197
)
198+
vars_in_varinfo = keys(varinfo)
199+
for v in keys(initial_params)
200+
vn = VarName{v}()
201+
if !(vn in vars_in_varinfo)
202+
for vv in vars_in_varinfo
203+
if subsumes(vn, vv)
204+
throw(
205+
ArgumentError(
206+
"The current model contains sub-variables of $v, such as ($vv). " *
207+
"Using NamedTuple for initial_params is not supported in such a case. " *
208+
"Please use AbstractVector for initial_params instead of NamedTuple.",
209+
),
210+
)
211+
end
212+
end
213+
throw(ArgumentError("Variable $v not found in the model."))
214+
end
215+
end
186216
initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing)
187217
return update_values!!(
188218
varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params))

test/sampler.jl

+25
Original file line numberDiff line numberDiff line change
@@ -178,5 +178,30 @@
178178
@test c1[1].metadata.s.vals == c2[1].metadata.s.vals
179179
end
180180
end
181+
182+
@testset "error handling" begin
183+
# https://github.com/TuringLang/Turing.jl/issues/2452
184+
@model function constrained_uniform(n)
185+
Z ~ Uniform(10, 20)
186+
X = Vector{Float64}(undef, n)
187+
for i in 1:n
188+
X[i] ~ Uniform(0, Z)
189+
end
190+
end
191+
192+
n = 2
193+
initial_z = 15
194+
initial_x = [0.2, 0.5]
195+
model = constrained_uniform(n)
196+
vi = VarInfo(model)
197+
198+
@test_throws ArgumentError DynamicPPL.initialize_parameters!!(
199+
vi, [initial_z, initial_x], DynamicPPL.SampleFromPrior(), model
200+
)
201+
202+
@test_throws ArgumentError DynamicPPL.initialize_parameters!!(
203+
vi, (X=initial_x, Z=initial_z), DynamicPPL.SampleFromPrior(), model
204+
)
205+
end
181206
end
182207
end

0 commit comments

Comments
 (0)