Skip to content

Commit d6d1761

Browse files
committed
Unify sample transformation and improve rand for measures
1 parent 5a79df8 commit d6d1761

11 files changed

+413
-83
lines changed

src/BAT.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ using Accessors: @set
6767

6868
import HeterogeneousComputing
6969
using HeterogeneousComputing: AbstractComputeUnit, CPUnit
70-
using HeterogeneousComputing: GenContext, get_rng, get_precision, get_compute_unit, get_gencontext
70+
using HeterogeneousComputing: GenContext, get_rng, get_precision, get_compute_unit, get_gencontext, allocate_array
7171

7272
import MeasureBase
7373
using MeasureBase: AbstractMeasure, DensityMeasure, Likelihood

src/algotypes/transform_algorithm.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ end
277277
struct SampleTransformation <: TransformAlgorithm end
278278

279279
function bat_transform_impl(f::Function, smpls::DensitySampleVector, ::SampleTransformation, context::BATContext)
280-
(result = broadcast_arbitrary_trafo(f, smpls), f_transform = f)
280+
(result = transform_samples(f, smpls), f_transform = f)
281281
end
282282

283283
function bat_transform_impl(shp::AbstractValueShape, smpls::DensitySampleVector, ::SampleTransformation, context::BATContext)

src/initvals/initvals.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ end
7373

7474
function bat_initval_impl(target::BATPushFwdMeasure, n::Integer, algorithm::InitFromTarget, context::BATContext)
7575
vs_orig = bat_initval_impl(target.orig, n, algorithm, context).result
76-
vs = BAT.broadcast_trafo(gettransform(target), vs_orig)
76+
vs = BAT.transform_samples(gettransform(target), vs_orig)
7777
(result = vs,)
7878
end
7979

@@ -164,7 +164,7 @@ end
164164

165165

166166
function apply_trafo_to_init(f_transform::Function, initalg::ExplicitInit)
167-
xs_tr = broadcast_trafo(f_transform, initalg.xs)
167+
xs_tr = transform_samples(f_transform, initalg.xs)
168168
ExplicitInit(xs_tr)
169169
end
170170

src/measures/bat_pushfwd_measure.jl

+9-3
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,15 @@ function checked_logdensityof(m::BATPushFwdMeasure{F,I,M,KeepRootMeasure}, v::An
135135
end
136136

137137

138-
Random.rand(rng::AbstractRNG, ::Type{T}, m::BATPushFwdMeasure) where {T<:Real} = m.f(rand(rng, T, m.origin))
139-
140-
Random.rand(rng::AbstractRNG, m::BATPushFwdMeasure) = m.f(rand(rng, m.origin))
138+
Random.rand(gen::GenContext, m::BATPushFwdMeasure) = m.f(rand(gen, m.origin))
139+
140+
function Base.rand(gen::GenContext, m::BATPwrMeasure{<:BATPushFwdMeasure})
141+
m_nonpwr, sz = m.parent, m.sz
142+
f = m_nonpwr.f
143+
m_origin = m_nonpwr.origin ^ sz
144+
X_origin = rand(gen, m_origin)
145+
return transform_samples(f, X_origin)
146+
end
141147

142148
supports_rand(m::BATPushFwdMeasure) = supports_rand(m.origin)
143149

src/measures/bat_pwr_measure.jl

+2-7
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ function _cartidxs(axs::Tuple{Vararg{AbstractUnitRange,N}}) where {N}
3434
CartesianIndices(map(_dynamic, axs))
3535
end
3636

37+
3738
function Base.rand(gen::GenContext, m::BATPwrMeasure)
3839
cunit = get_compute_unit(gen)
3940
rng = get_rng(gen)
@@ -47,12 +48,6 @@ function Base.rand(gen::GenContext, m::BATPwrMeasure{<:BATDistMeasure})
4748
gen_adapt(gen, reshaped_X)
4849
end
4950

50-
function Base.rand(gen::GenContext, m::BATPwrMeasure{<:BATDistMeasure})
51-
X = rand(get_rng(gen), m.parent.dist, size(marginals(m))...)
52-
reshaped_X = _reshape_rand_n_output(X)
53-
gen_adapt(gen, reshaped_X)
54-
end
55-
5651
function Base.rand(gen::GenContext, m::BATPwrMeasure{<:DensitySampleMeasure})
5752
# Always generate R on CPU for now:
5853
R = rand(get_rng(gen), size(marginals(m))...)
@@ -75,7 +70,7 @@ MeasureBase.marginals(m::BATPwrMeasure) = Fill(_pwr_base(m), _pwr_size(m))
7570
@assert size(x) == _pwr_size(m)
7671
m_base = _pwr_base(m)
7772
sum(x) do x_i
78-
logdensity_def(m_base, x_i)
73+
logdensityof(m_base, x_i)
7974
end
8075
end
8176

src/measures/bat_weighted_measure.jl

+7
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ end
5050

5151
Base.rand(gen::GenContext, m::BATWeightedMeasure) = rand(gen, m.base)
5252

53+
function Base.rand(gen::GenContext, m::BATPwrMeasure{<:BATWeightedMeasure})
54+
m_nonpwr, sz = m.parent, m.sz
55+
m_origin = m_nonpwr.base ^ sz
56+
X_origin = rand(gen, m_origin)
57+
return X_origin
58+
end
59+
5360
supports_rand(m::BATWeightedMeasure) = supports_rand(m.origin)
5461

5562

src/transforms/distribution_transform.jl

+2-8
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,6 @@ struct DistributionTransform{
112112
end
113113

114114

115-
# ToDo: Unify with broadcast_trafo
116-
function broadcast_arbitrary_trafo(f_transform::DistributionTransform, smpls::DensitySampleVector)
117-
broadcast_trafo(f_transform, smpls)
118-
end
119-
120-
121115
function _distrafo_ctor_impl(target_dist::DT, source_dist::DF) where {DT<:ContinuousDistribution,DF<:ContinuousDistribution}
122116
@argcheck eff_totalndof(target_dist) == eff_totalndof(source_dist)
123117
DistributionTransform{DT,DF}(target_dist, source_dist)
@@ -177,7 +171,7 @@ function Base.Broadcast.broadcasted(
177171
f_transform::DistributionTransform,
178172
v_src::Union{ArrayOfSimilarVectors{<:Real},ShapedAsNTArray}
179173
)
180-
broadcast_trafo(f_transform, v_src)
174+
transform_samples(f_transform, v_src)
181175
end
182176

183177

@@ -192,7 +186,7 @@ function Base.Broadcast.broadcasted(
192186
f_transform::DistributionTransform,
193187
s_src::DensitySampleVector
194188
)
195-
broadcast_trafo(f_transform, s_src)
189+
transform_samples(f_transform, s_src)
196190
end
197191

198192

0 commit comments

Comments
 (0)