Skip to content

Commit 1366440

Browse files
mhaurusunxd3penelopeysm
authored
Remove samplers from VarInfo - indexing (#793)
* Remove selector stuff from varinfo tests * Implement link and invlink for varnames rather than samplers * Replace set_retained_vns_del_by_spl! with set_retained_vns_del! * Make linking tests more extensive * Remove sampler indexing from link methods (but not invlink) * Remove indexing by samplers from invlink * Work towards removing sampler indexing with StaticTransformation * Fix invlink/link for TypedVarInfo and StaticTransformation * Fix a test in models.jl * Move some functions to utils.jl, add tests and docstrings * Fix a docstring typo * Various simplification to link/invlink * Improve a docstring * Style improvements * Fix broken link/invlink dispatch cascade for VectorVarInfo * Fix some more broken dispatch cascades * Apply suggestions from code review Co-authored-by: Xianda Sun <[email protected]> * Remove comments that messed with docstrings * Apply suggestions from code review Co-authored-by: Penelope Yong <[email protected]> * Fix issues surfaced in code review * Simplify link/invlink arguments * Fix a bug in unflatten VarNamedVector * Rename VarNameCollection -> VarNameTuple * Remove test of a removed varname_namedtuple method * Apply suggestions from code review Co-authored-by: Penelope Yong <[email protected]> * Respond to review feedback * Remove _default_sampler and a dead argument of maybe_invlink_before_eval * Fix a typo in a comment * Add HISTORY entry, fix one set_retained_vns_del! method * Remove some VarInfo getindex with samplers stuff * Remove some index setting with samplers * Remove more sampler indexing * Remove unflatten with samplers * Clean up some setindex stuff * Remove a bunch of varinfo.jl internal functions that used samplers/space, update HISTORY.md * Fix HISTORY.md * Miscalleanous small fixes * Fix a bug in VarInfo constructor * Fix getparams(::LogDensityFunction) * Apply suggestions from code review Co-authored-by: Penelope Yong <[email protected]> --------- Co-authored-by: Xianda Sun <[email protected]> Co-authored-by: Penelope Yong <[email protected]>
1 parent c5f2f7a commit 1366440

13 files changed

+101
-401
lines changed

HISTORY.md

+5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ This release removes the feature of `VarInfo` where it kept track of which varia
1010

1111
- `link` and `invlink`, and their `!!` versions, no longer accept a sampler as an argument to specify which variables to (inv)link. The `link(varinfo, model)` methods remain in place, and as a new addition one can give a `Tuple` of `VarName`s to (inv)link only select variables, as in `link(varinfo, varname_tuple, model)`.
1212
- `set_retained_vns_del_by_spl!` has been replaced by `set_retained_vns_del!` which applies to all variables.
13+
- `getindex`, `setindex!`, and `setindex!!` no longer accept samplers as arguments
14+
- `unflatten` no longer accepts a sampler as an argument
15+
- `eltype(::VarInfo)` no longer accepts a sampler as an argument
16+
- `keys(::VarInfo)` no longer accepts a sampler as an argument
17+
- `VarInfo(::VarInfo, ::Sampler, ::AbstactVector)` no longer accepts the sampler argument.
1318

1419
### Reverse prefixing order
1520

src/abstract_varinfo.jl

+8-24
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,6 @@ If `dist` is specified, the value(s) will be massaged into the representation ex
149149

150150
"""
151151
getindex(vi::AbstractVarInfo, ::Colon)
152-
getindex(vi::AbstractVarInfo, ::AbstractSampler)
153152
154153
Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their)
155154
distribution(s) as a flattened `Vector`.
@@ -159,7 +158,6 @@ The default implementation is to call [`values_as`](@ref) with `Vector` as the t
159158
See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref)
160159
"""
161160
Base.getindex(vi::AbstractVarInfo, ::Colon) = values_as(vi, Vector)
162-
Base.getindex(vi::AbstractVarInfo, ::AbstractSampler) = vi[:]
163161

164162
"""
165163
getindex_internal(vi::AbstractVarInfo, vn::VarName)
@@ -341,9 +339,9 @@ julia> values_as(vi, Vector)
341339
function values_as end
342340

343341
"""
344-
eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior}
342+
eltype(vi::AbstractVarInfo)
345343
346-
Determine the default `eltype` of the values returned by `vi[spl]`.
344+
Return the `eltype` of the values returned by `vi[:]`.
347345
348346
!!! warning
349347
This should generally not be called explicitly, as it's only used in
@@ -352,13 +350,13 @@ Determine the default `eltype` of the values returned by `vi[spl]`.
352350
353351
This method is considered legacy, and is likely to be deprecated in the future.
354352
"""
355-
function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior})
356-
T = Base.promote_op(getindex, typeof(vi), typeof(spl))
353+
function Base.eltype(vi::AbstractVarInfo)
354+
T = Base.promote_op(getindex, typeof(vi), Colon)
357355
if T === Union{}
358-
# In this case `getindex(vi, spl)` errors
356+
# In this case `getindex(vi, :)` errors
359357
# Let us throw a more descriptive error message
360358
# Ref https://github.com/TuringLang/Turing.jl/issues/2151
361-
return eltype(vi[spl])
359+
return eltype(vi[:])
362360
end
363361
return eltype(T)
364362
end
@@ -720,25 +718,11 @@ end
720718

721719
# Utilities
722720
"""
723-
unflatten(vi::AbstractVarInfo[, context::AbstractContext], x::AbstractVector)
721+
unflatten(vi::AbstractVarInfo, x::AbstractVector)
724722
725723
Return a new instance of `vi` with the values of `x` assigned to the variables.
726-
727-
If `context` is provided, `x` is assumed to be realizations only for variables not
728-
filtered out by `context`.
729724
"""
730-
function unflatten(varinfo::AbstractVarInfo, context::AbstractContext, θ)
731-
if hassampler(context)
732-
unflatten(getsampler(context), varinfo, context, θ)
733-
else
734-
DynamicPPL.unflatten(varinfo, θ)
735-
end
736-
end
737-
738-
# TODO: deprecate this once `sampler` is no longer the main way of filtering out variables.
739-
function unflatten(sampler::AbstractSampler, varinfo::AbstractVarInfo, ::AbstractContext, θ)
740-
return unflatten(varinfo, sampler, θ)
741-
end
725+
function unflatten end
742726

743727
"""
744728
to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val)

src/compiler.jl

+25-40
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)
33
"""
44
need_concretize(expr)
55
6-
Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or
6+
Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or
77
requires a dynamic optic.
88
99
# Examples
@@ -730,19 +730,19 @@ function warn_empty(body)
730730
return nothing
731731
end
732732

733+
# TODO(mhauru) matchingvalue has methods that can accept both types and values. Why?
734+
# TODO(mhauru) This function needs a more comprehensive docstring.
733735
"""
734-
matchingvalue(sampler, vi, value)
735-
matchingvalue(context::AbstractContext, vi, value)
736+
matchingvalue(vi, value)
736737
737-
Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object.
738-
739-
For a `context` that is _not_ a `SamplingContext`, we fall back to
740-
`matchingvalue(SampleFromPrior(), vi, value)`.
738+
Convert the `value` to the correct type for the `vi` object.
741739
"""
742-
function matchingvalue(sampler, vi, value)
740+
function matchingvalue(vi, value)
743741
T = typeof(value)
744742
if hasmissing(T)
745-
_value = convert(get_matching_type(sampler, vi, T), value)
743+
_value = convert(get_matching_type(vi, T), value)
744+
# TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we
745+
# are happy to return `value` as-is?
746746
if _value === value
747747
return deepcopy(_value)
748748
else
@@ -752,45 +752,30 @@ function matchingvalue(sampler, vi, value)
752752
return value
753753
end
754754
end
755-
# If we hit `Type` or `TypeWrap`, we immediately jump to `get_matching_type`.
756-
function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType)
757-
return get_matching_type(sampler, vi, value)
758-
end
759-
function matchingvalue(sampler::AbstractSampler, vi, value::TypeWrap{T}) where {T}
760-
return TypeWrap{get_matching_type(sampler, vi, T)}()
761-
end
762755

763-
function matchingvalue(context::AbstractContext, vi, value)
764-
return matchingvalue(NodeTrait(matchingvalue, context), context, vi, value)
756+
function matchingvalue(vi, value::FloatOrArrayType)
757+
return get_matching_type(vi, value)
765758
end
766-
function matchingvalue(::IsLeaf, context::AbstractContext, vi, value)
767-
return matchingvalue(SampleFromPrior(), vi, value)
768-
end
769-
function matchingvalue(::IsParent, context::AbstractContext, vi, value)
770-
return matchingvalue(childcontext(context), vi, value)
771-
end
772-
function matchingvalue(context::SamplingContext, vi, value)
773-
return matchingvalue(context.sampler, vi, value)
759+
function matchingvalue(vi, ::TypeWrap{T}) where {T}
760+
return TypeWrap{get_matching_type(vi, T)}()
774761
end
775762

763+
# TODO(mhauru) This function needs a more comprehensive docstring. What is it for?
776764
"""
777-
get_matching_type(spl::AbstractSampler, vi, ::TypeWrap{T}) where {T}
778-
779-
Get the specialized version of type `T` for sampler `spl`.
765+
get_matching_type(vi, ::TypeWrap{T}) where {T}
780766
781-
For example, if `T === Float64` and `spl::Hamiltonian`, the matching type is
782-
`eltype(vi[spl])`.
767+
Get the specialized version of type `T` for `vi`.
783768
"""
784-
get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T} = T
785-
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Union{Missing,AbstractFloat}})
786-
return Union{Missing,float_type_with_fallback(eltype(vi, spl))}
769+
get_matching_type(_, ::Type{T}) where {T} = T
770+
function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}})
771+
return Union{Missing,float_type_with_fallback(eltype(vi))}
787772
end
788-
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:AbstractFloat})
789-
return float_type_with_fallback(eltype(vi, spl))
773+
function get_matching_type(vi, ::Type{<:AbstractFloat})
774+
return float_type_with_fallback(eltype(vi))
790775
end
791-
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T,N}}) where {T,N}
792-
return Array{get_matching_type(spl, vi, T),N}
776+
function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N}
777+
return Array{get_matching_type(vi, T),N}
793778
end
794-
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T}}) where {T}
795-
return Array{get_matching_type(spl, vi, T)}
779+
function get_matching_type(vi, ::Type{<:Array{T}}) where {T}
780+
return Array{get_matching_type(vi, T)}
796781
end

src/logdensityfunction.jl

+2-7
Original file line numberDiff line numberDiff line change
@@ -121,22 +121,17 @@ end
121121
getsampler(f::LogDensityFunction) = getsampler(getcontext(f))
122122
hassampler(f::LogDensityFunction) = hassampler(getcontext(f))
123123

124-
_get_indexer(ctx::AbstractContext) = _get_indexer(NodeTrait(ctx), ctx)
125-
_get_indexer(ctx::SamplingContext) = ctx.sampler
126-
_get_indexer(::IsParent, ctx::AbstractContext) = _get_indexer(childcontext(ctx))
127-
_get_indexer(::IsLeaf, ctx::AbstractContext) = Colon()
128-
129124
"""
130125
getparams(f::LogDensityFunction)
131126
132127
Return the parameters of the wrapped varinfo as a vector.
133128
"""
134-
getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(getcontext(f))]
129+
getparams(f::LogDensityFunction) = f.varinfo[:]
135130

136131
# LogDensityProblems interface
137132
function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector)
138133
context = getcontext(f)
139-
vi_new = unflatten(f.varinfo, context, θ)
134+
vi_new = unflatten(f.varinfo, θ)
140135
return getlogp(last(evaluate!!(f.model, vi_new, context)))
141136
end
142137
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})

src/model.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -948,9 +948,9 @@ Return the arguments and keyword arguments to be passed to the evaluator of the
948948
) where {_F,argnames}
949949
unwrap_args = [
950950
if is_splat_symbol(var)
951-
:($matchingvalue(context_new, varinfo, model.args.$var)...)
951+
:($matchingvalue(varinfo, model.args.$var)...)
952952
else
953-
:($matchingvalue(context_new, varinfo, model.args.$var))
953+
:($matchingvalue(varinfo, model.args.$var))
954954
end for var in argnames
955955
]
956956

src/sampler.jl

+12-19
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ function AbstractMCMC.step(
118118

119119
# Update the parameters if provided.
120120
if initial_params !== nothing
121-
vi = initialize_parameters!!(vi, initial_params, spl, model)
121+
vi = initialize_parameters!!(vi, initial_params, model)
122122

123123
# Update joint log probability.
124124
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
@@ -156,9 +156,7 @@ By default, it returns an instance of [`SampleFromPrior`](@ref).
156156
"""
157157
initialsampler(spl::Sampler) = SampleFromPrior()
158158

159-
function set_values!!(
160-
varinfo::AbstractVarInfo, initial_params::AbstractVector, spl::AbstractSampler
161-
)
159+
function set_values!!(varinfo::AbstractVarInfo, initial_params::AbstractVector)
162160
throw(
163161
ArgumentError(
164162
"`initial_params` must be a vector of type `Union{Real,Missing}`. " *
@@ -168,11 +166,9 @@ function set_values!!(
168166
end
169167

170168
function set_values!!(
171-
varinfo::AbstractVarInfo,
172-
initial_params::AbstractVector{<:Union{Real,Missing}},
173-
spl::AbstractSampler,
169+
varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}}
174170
)
175-
flattened_param_vals = varinfo[spl]
171+
flattened_param_vals = varinfo[:]
176172
length(flattened_param_vals) == length(initial_params) || throw(
177173
DimensionMismatch(
178174
"Provided initial value size ($(length(initial_params))) doesn't match " *
@@ -189,12 +185,11 @@ function set_values!!(
189185
end
190186

191187
# Update in `varinfo`.
192-
return setindex!!(varinfo, flattened_param_vals, spl)
188+
setall!(varinfo, flattened_param_vals)
189+
return varinfo
193190
end
194191

195-
function set_values!!(
196-
varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler
197-
)
192+
function set_values!!(varinfo::AbstractVarInfo, initial_params::NamedTuple)
198193
vars_in_varinfo = keys(varinfo)
199194
for v in keys(initial_params)
200195
vn = VarName{v}()
@@ -219,23 +214,21 @@ function set_values!!(
219214
)
220215
end
221216

222-
function initialize_parameters!!(
223-
vi::AbstractVarInfo, initial_params, spl::AbstractSampler, model::Model
224-
)
217+
function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model)
225218
@debug "Using passed-in initial variable values" initial_params
226219

227220
# `link` the varinfo if needed.
228-
linked = islinked(vi, spl)
221+
linked = islinked(vi)
229222
if linked
230-
vi = invlink!!(vi, spl, model)
223+
vi = invlink!!(vi, model)
231224
end
232225

233226
# Set the values in `vi`.
234-
vi = set_values!!(vi, initial_params, spl)
227+
vi = set_values!!(vi, initial_params)
235228

236229
# `invlink` if needed.
237230
if linked
238-
vi = link!!(vi, spl, model)
231+
vi = link!!(vi, model)
239232
end
240233

241234
return vi

src/simple_varinfo.jl

+2-11
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,6 @@ function typed_simple_varinfo(model::Model)
258258
return last(evaluate!!(model, varinfo, SamplingContext()))
259259
end
260260

261-
unflatten(svi::SimpleVarInfo, spl::AbstractSampler, x::AbstractVector) = unflatten(svi, x)
262261
function unflatten(svi::SimpleVarInfo, x::AbstractVector)
263262
logp = getlogp(svi)
264263
vals = unflatten(svi.values, x)
@@ -342,10 +341,6 @@ function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName)
342341
return Accessors.@set vi.values = set!!(vi.values, vn, val)
343342
end
344343

345-
function BangBang.setindex!!(vi::SimpleVarInfo, val, spl::AbstractSampler)
346-
return unflatten(vi, spl, val)
347-
end
348-
349344
# TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with
350345
# same symbol and same type of, say, `IndexLens`, for improved `.~` performance.
351346
function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName})
@@ -428,11 +423,7 @@ const SimpleOrThreadSafeSimple{T,V,C} = Union{
428423
}
429424

430425
# Necessary for `matchingvalue` to work properly.
431-
function Base.eltype(
432-
vi::SimpleOrThreadSafeSimple{<:Any,V}, spl::Union{AbstractSampler,SampleFromPrior}
433-
) where {V}
434-
return V
435-
end
426+
Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V
436427

437428
# `subset`
438429
function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
@@ -562,7 +553,7 @@ istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
562553
istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi)
563554
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn)
564555

565-
islinked(vi::SimpleVarInfo, ::Union{Sampler,SampleFromPrior}) = istrans(vi)
556+
islinked(vi::SimpleVarInfo) = istrans(vi)
566557

567558
values_as(vi::SimpleVarInfo) = vi.values
568559
values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values

src/threadsafe.jl

+1-19
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn)
7979
keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo)
8080
haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn)
8181

82-
islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl)
82+
islinked(vi::ThreadSafeVarInfo) = islinked(vi.varinfo)
8383

8484
function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
8585
return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...)
@@ -138,17 +138,6 @@ end
138138
function getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::Distribution)
139139
return getindex(vi.varinfo, vns, dist)
140140
end
141-
getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl)
142-
143-
function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler)
144-
return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl)
145-
end
146-
function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromPrior)
147-
return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl)
148-
end
149-
function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromUniform)
150-
return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl)
151-
end
152141

153142
function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vn::VarName)
154143
return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vn)
@@ -184,13 +173,9 @@ function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String)
184173
return is_flagged(vi.varinfo, vn, flag)
185174
end
186175

187-
# Transformations.
188176
function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName)
189177
return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn)
190178
end
191-
function settrans!!(vi::ThreadSafeVarInfo, spl::AbstractSampler, dist::Distribution)
192-
return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, spl, dist)
193-
end
194179

195180
istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn)
196181
istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns)
@@ -200,9 +185,6 @@ getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.var
200185
function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector)
201186
return Accessors.@set vi.varinfo = unflatten(vi.varinfo, x)
202187
end
203-
function unflatten(vi::ThreadSafeVarInfo, spl::AbstractSampler, x::AbstractVector)
204-
return Accessors.@set vi.varinfo = unflatten(vi.varinfo, spl, x)
205-
end
206188

207189
function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName})
208190
return Accessors.@set varinfo.varinfo = subset(varinfo.varinfo, vns)

0 commit comments

Comments
 (0)