Skip to content

Commit c5f2f7a

Browse files
mhaurusunxd3penelopeysm
authored
Remove selector stuff from VarInfo tests and link/invlink (#780)
* Remove selector stuff from varinfo tests * Implement link and invlink for varnames rather than samplers * Replace set_retained_vns_del_by_spl! with set_retained_vns_del! * Make linking tests more extensive * Remove sampler indexing from link methods (but not invlink) * Remove indexing by samplers from invlink * Work towards removing sampler indexing with StaticTransformation * Fix invlink/link for TypedVarInfo and StaticTransformation * Fix a test in models.jl * Move some functions to utils.jl, add tests and docstrings * Fix a docstring typo * Various simplification to link/invlink * Improve a docstring * Style improvements * Fix broken link/invlink dispatch cascade for VectorVarInfo * Fix some more broken dispatch cascades * Apply suggestions from code review Co-authored-by: Xianda Sun <[email protected]> * Remove comments that messed with docstrings * Apply suggestions from code review Co-authored-by: Penelope Yong <[email protected]> * Fix issues surfaced in code review * Simplify link/invlink arguments * Fix a bug in unflatten VarNamedVector * Rename VarNameCollection -> VarNameTuple * Remove test of a removed varname_namedtuple method * Apply suggestions from code review Co-authored-by: Penelope Yong <[email protected]> * Respond to review feedback * Remove _default_sampler and a dead argument of maybe_invlink_before_eval * Fix a typo in a comment * Add HISTORY entry, fix one set_retained_vns_del! method --------- Co-authored-by: Xianda Sun <[email protected]> Co-authored-by: Penelope Yong <[email protected]>
1 parent 7140f3d commit c5f2f7a

15 files changed

+513
-501
lines changed

HISTORY.md

+9
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,15 @@
44

55
**Breaking**
66

7+
### Remove indexing by samplers
8+
9+
This release removes the feature of `VarInfo` where it kept track of which variable was associated with which sampler. This means removing all user-facing methods where `VarInfo`s where being indexed with samplers. In particular,
10+
11+
- `link` and `invlink`, and their `!!` versions, no longer accept a sampler as an argument to specify which variables to (inv)link. The `link(varinfo, model)` methods remain in place, and as a new addition one can give a `Tuple` of `VarName`s to (inv)link only select variables, as in `link(varinfo, varname_tuple, model)`.
12+
- `set_retained_vns_del_by_spl!` has been replaced by `set_retained_vns_del!` which applies to all variables.
13+
14+
### Reverse prefixing order
15+
716
- For submodels constructed using `to_submodel`, the order in which nested prefixes are applied has been changed.
817
Previously, the order was that outer prefixes were applied first, then inner ones.
918
This version reverses that.

docs/src/api.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ set_num_produce!
304304
increment_num_produce!
305305
reset_num_produce!
306306
setorder!
307-
set_retained_vns_del_by_spl!
307+
set_retained_vns_del!
308308
```
309309

310310
```@docs

src/DynamicPPL.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ export AbstractVarInfo,
5959
set_num_produce!,
6060
reset_num_produce!,
6161
increment_num_produce!,
62-
set_retained_vns_del_by_spl!,
62+
set_retained_vns_del!,
6363
is_flagged,
6464
set_flag!,
6565
unset_flag!,

src/abstract_varinfo.jl

+60-70
Original file line numberDiff line numberDiff line change
@@ -537,117 +537,118 @@ If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variabl
537537
"""
538538
function settrans!! end
539539

540+
# For link!!, invlink!!, link, and invlink, we deliberately do not provide a fallback
541+
# method for the case when no `vns` is provided, that would get all the keys from the
542+
# `VarInfo`. Hence each subtype of `AbstractVarInfo` needs to implement separately the case
543+
# where `vns` is provided and the one where it is not. This is because having separate
544+
# implementations is typically much more performant, and because not all AbstractVarInfo
545+
# types support partial linking.
546+
540547
"""
541548
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
542-
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
549+
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model)
550+
551+
Transform variables in `vi` to their linked space, mutating `vi` if possible.
543552
544-
Transform the variables in `vi` to their linked space, using the transformation `t`,
545-
mutating `vi` if possible.
553+
Either transform all variables, or only ones specified in `vns`.
546554
547-
If `t` is not provided, `default_transformation(model, vi)` will be used.
555+
Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided.
548556
549557
See also: [`default_transformation`](@ref), [`invlink!!`](@ref).
550558
"""
551-
link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model)
552-
function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
553-
return link!!(t, vi, SampleFromPrior(), model)
559+
function link!!(vi::AbstractVarInfo, model::Model)
560+
return link!!(default_transformation(model, vi), vi, model)
554561
end
555-
function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
556-
# Use `default_transformation` to decide which transformation to use if none is specified.
557-
return link!!(default_transformation(model, vi), vi, spl, model)
562+
function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
563+
return link!!(default_transformation(model, vi), vi, vns, model)
558564
end
559565

560566
"""
561567
link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
562-
link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
568+
link([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model)
569+
570+
Transform variables in `vi` to their linked space without mutating `vi`.
563571
564-
Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`.
572+
Either transform all variables, or only ones specified in `vns`.
565573
566-
If `t` is not provided, `default_transformation(model, vi)` will be used.
574+
Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided.
567575
568576
See also: [`default_transformation`](@ref), [`invlink`](@ref).
569577
"""
570-
link(vi::AbstractVarInfo, model::Model) = link(vi, SampleFromPrior(), model)
571-
function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
572-
return link(t, deepcopy(vi), SampleFromPrior(), model)
578+
function link(vi::AbstractVarInfo, model::Model)
579+
return link(default_transformation(model, vi), vi, model)
573580
end
574-
function link(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
575-
# Use `default_transformation` to decide which transformation to use if none is specified.
576-
return link(default_transformation(model, vi), deepcopy(vi), spl, model)
581+
function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
582+
return link(default_transformation(model, vi), vi, vns, model)
577583
end
578584

579585
"""
580586
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
581-
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
587+
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model)
582588
583-
Transform the variables in `vi` to their constrained space, using the (inverse of)
584-
transformation `t`, mutating `vi` if possible.
589+
Transform variables in `vi` to their constrained space, mutating `vi` if possible.
585590
586-
If `t` is not provided, `default_transformation(model, vi)` will be used.
591+
Either transform all variables, or only ones specified in `vns`.
592+
593+
Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is
594+
not provided.
587595
588596
See also: [`default_transformation`](@ref), [`link!!`](@ref).
589597
"""
590-
invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), model)
591-
function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
592-
return invlink!!(t, vi, SampleFromPrior(), model)
598+
function invlink!!(vi::AbstractVarInfo, model::Model)
599+
return invlink!!(default_transformation(model, vi), vi, model)
593600
end
594-
function invlink!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
595-
# Here we extract the `transformation` from `vi` rather than using the default one.
596-
return invlink!!(transformation(vi), vi, spl, model)
601+
function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
602+
return invlink!!(default_transformation(model, vi), vi, vns, model)
597603
end
598604

599605
# Vector-based ones.
600606
function link!!(
601-
t::StaticTransformation{<:Bijectors.Transform},
602-
vi::AbstractVarInfo,
603-
spl::AbstractSampler,
604-
model::Model,
607+
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
605608
)
606609
b = inverse(t.bijector)
607-
x = vi[spl]
610+
x = vi[:]
608611
y, logjac = with_logabsdet_jacobian(b, x)
609612

610613
lp_new = getlogp(vi) - logjac
611-
vi_new = setlogp!!(unflatten(vi, spl, y), lp_new)
614+
vi_new = setlogp!!(unflatten(vi, y), lp_new)
612615
return settrans!!(vi_new, t)
613616
end
614617

615618
function invlink!!(
616-
t::StaticTransformation{<:Bijectors.Transform},
617-
vi::AbstractVarInfo,
618-
spl::AbstractSampler,
619-
model::Model,
619+
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
620620
)
621621
b = t.bijector
622-
y = vi[spl]
622+
y = vi[:]
623623
x, logjac = with_logabsdet_jacobian(b, y)
624624

625625
lp_new = getlogp(vi) + logjac
626-
vi_new = setlogp!!(unflatten(vi, spl, x), lp_new)
626+
vi_new = setlogp!!(unflatten(vi, x), lp_new)
627627
return settrans!!(vi_new, NoTransformation())
628628
end
629629

630630
"""
631631
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
632-
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
632+
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model)
633+
634+
Transform variables in `vi` to their constrained space without mutating `vi`.
633635
634-
Transform the variables in `vi` to their constrained space without mutating `vi`, using the (inverse of)
635-
transformation `t`.
636+
Either transform all variables, or only ones specified in `vns`.
636637
637-
If `t` is not provided, `default_transformation(model, vi)` will be used.
638+
Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is
639+
not provided.
638640
639641
See also: [`default_transformation`](@ref), [`link`](@ref).
640642
"""
641-
invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), model)
642-
function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
643-
return invlink(t, vi, SampleFromPrior(), model)
643+
function invlink(vi::AbstractVarInfo, model::Model)
644+
return invlink(default_transformation(model, vi), vi, model)
644645
end
645-
function invlink(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
646-
return invlink(transformation(vi), vi, spl, model)
646+
function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
647+
return invlink(default_transformation(model, vi), vi, vns, model)
647648
end
648649

649650
"""
650-
maybe_invlink_before_eval!!([t::Transformation,] vi, context, model)
651+
maybe_invlink_before_eval!!([t::Transformation,] vi, model)
651652
652653
Return a possibly invlinked version of `vi`.
653654
@@ -698,34 +699,23 @@ julia> # Now performs a single `invlink!!` before model evaluation.
698699
-1001.4189385332047
699700
```
700701
"""
701-
function maybe_invlink_before_eval!!(
702-
vi::AbstractVarInfo, context::AbstractContext, model::Model
703-
)
704-
return maybe_invlink_before_eval!!(transformation(vi), vi, context, model)
702+
function maybe_invlink_before_eval!!(vi::AbstractVarInfo, model::Model)
703+
return maybe_invlink_before_eval!!(transformation(vi), vi, model)
705704
end
706-
function maybe_invlink_before_eval!!(
707-
::NoTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model
708-
)
705+
function maybe_invlink_before_eval!!(::NoTransformation, vi::AbstractVarInfo, model::Model)
709706
return vi
710707
end
711708
function maybe_invlink_before_eval!!(
712-
::DynamicTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model
709+
::DynamicTransformation, vi::AbstractVarInfo, model::Model
713710
)
714-
# `DynamicTransformation` is meant to _not_ do the transformation statically, hence we do nothing.
711+
# `DynamicTransformation` is meant to _not_ do the transformation statically, hence we
712+
# do nothing.
715713
return vi
716714
end
717715
function maybe_invlink_before_eval!!(
718-
t::StaticTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model
716+
t::StaticTransformation, vi::AbstractVarInfo, model::Model
719717
)
720-
return invlink!!(t, vi, _default_sampler(context), model)
721-
end
722-
723-
function _default_sampler(context::AbstractContext)
724-
return _default_sampler(NodeTrait(_default_sampler, context), context)
725-
end
726-
_default_sampler(::IsLeaf, context::AbstractContext) = SampleFromPrior()
727-
function _default_sampler(::IsParent, context::AbstractContext)
728-
return _default_sampler(childcontext(context))
718+
return invlink!!(t, vi, model)
729719
end
730720

731721
# Utilities

src/model.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ Return the arguments and keyword arguments to be passed to the evaluator of the
971971
# lazy `invlink`-ing of the parameters. This can be useful for
972972
# speeding up computation. See docs for `maybe_invlink_before_eval!!`
973973
# for more information.
974-
maybe_invlink_before_eval!!(varinfo, context_new, model),
974+
maybe_invlink_before_eval!!(varinfo, model),
975975
context_new,
976976
$(unwrap_args...),
977977
)
@@ -1169,10 +1169,10 @@ end
11691169
"""
11701170
predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo})
11711171
1172-
Generate samples from the posterior predictive distribution by evaluating `model` at each set
1173-
of parameter values provided in `chain`. The number of posterior predictive samples matches
1172+
Generate samples from the posterior predictive distribution by evaluating `model` at each set
1173+
of parameter values provided in `chain`. The number of posterior predictive samples matches
11741174
the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values
1175-
and the predicted values.
1175+
and the predicted values.
11761176
"""
11771177
function predict(
11781178
rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo}

src/simple_varinfo.jl

+2-4
Original file line numberDiff line numberDiff line change
@@ -680,8 +680,7 @@ Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarIn
680680
function link!!(
681681
t::StaticTransformation{<:Bijectors.NamedTransform},
682682
vi::SimpleVarInfo{<:NamedTuple},
683-
spl::AbstractSampler,
684-
model::Model,
683+
::Model,
685684
)
686685
# TODO: Make sure that `spl` is respected.
687686
b = inverse(t.bijector)
@@ -695,8 +694,7 @@ end
695694
function invlink!!(
696695
t::StaticTransformation{<:Bijectors.NamedTransform},
697696
vi::SimpleVarInfo{<:NamedTuple},
698-
spl::AbstractSampler,
699-
model::Model,
697+
::Model,
700698
)
701699
# TODO: Make sure that `spl` is respected.
702700
b = t.bijector

src/threadsafe.jl

+21-40
Original file line numberDiff line numberDiff line change
@@ -81,70 +81,51 @@ haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn)
8181

8282
islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl)
8383

84-
function link!!(
85-
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
86-
)
87-
return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, spl, model)
84+
function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
85+
return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...)
8886
end
8987

90-
function invlink!!(
91-
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
92-
)
93-
return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model)
88+
function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
89+
return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...)
9490
end
9591

96-
function link(
97-
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
98-
)
99-
return Accessors.@set vi.varinfo = link(t, vi.varinfo, spl, model)
92+
function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
93+
return Accessors.@set vi.varinfo = link(t, vi.varinfo, args...)
10094
end
10195

102-
function invlink(
103-
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
104-
)
105-
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, spl, model)
96+
function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
97+
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, args...)
10698
end
10799

108100
# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity.
109101
# NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure
110102
# consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates
111103
# to define `getlogp(vi)`.
112-
function link!!(
113-
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
114-
)
104+
function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
115105
return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t)
116106
end
117107

118-
function invlink!!(
119-
::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
120-
)
108+
function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
121109
return settrans!!(
122110
last(evaluate!!(model, vi, DynamicTransformationContext{true}())),
123111
NoTransformation(),
124112
)
125113
end
126114

127-
function link(
128-
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
129-
)
130-
return link!!(t, deepcopy(vi), spl, model)
115+
function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
116+
return link!!(t, deepcopy(vi), model)
131117
end
132118

133-
function invlink(
134-
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
135-
)
136-
return invlink!!(t, deepcopy(vi), spl, model)
119+
function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
120+
return invlink!!(t, deepcopy(vi), model)
137121
end
138122

139-
function maybe_invlink_before_eval!!(
140-
vi::ThreadSafeVarInfo, context::AbstractContext, model::Model
141-
)
123+
function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model)
142124
# Defer to the wrapped `AbstractVarInfo` object.
143-
# NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the `getlogp(vi.varinfo)`
144-
# hence the log-absdet-jacobian term will correctly be included in the `getlogp(vi)`.
145-
return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(
146-
vi.varinfo, context, model
147-
)
125+
# NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the
126+
# `getlogp(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in
127+
# the `getlogp(vi)`.
128+
return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(vi.varinfo, model)
148129
end
149130

150131
# `getindex`
@@ -182,8 +163,8 @@ function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName})
182163
return vector_getranges(vi.varinfo, vns)
183164
end
184165

185-
function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler)
186-
return set_retained_vns_del_by_spl!(vi.varinfo, spl)
166+
function set_retained_vns_del!(vi::ThreadSafeVarInfo)
167+
return set_retained_vns_del!(vi.varinfo)
187168
end
188169

189170
isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo)

0 commit comments

Comments
 (0)