Skip to content

Commit ce94f19

Browse files
committed
Improve typing in transforms.jl (#232)
1 parent f8a2e11 commit ce94f19

File tree

2 files changed

+50
-41
lines changed

2 files changed

+50
-41
lines changed

src/core/util.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ macro VarName(ex::Union{Expr, Symbol})
2525
end
2626
end
2727

28-
invlogit(x) = 1.0 ./ (exp(-x) + 1.0)
28+
invlogit(x::Real) = one(x) / (one(x) + exp(-x))
2929

30-
logit(x) = log(x ./ (1.0 - x))
30+
logit(x::Real) = log(x / (one(x) - x))
3131

3232
function randcat(p::Vector{Float64}) # More stable, faster version of rand(Categorical)
3333
# if(any(p .< 0)) error("Negative probabilities not allowed"); end

src/samplers/support/transform.jl

Lines changed: 48 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@
2626
> SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
2727
=#
2828

29-
30-
#################### TransformDistribution ####################
29+
#############
30+
# a ≦ x ≦ b #
31+
#############
3132

3233
typealias TransformDistribution{T<:ContinuousUnivariateDistribution}
33-
Union{T, Truncated{T}}
34+
Union{T, Truncated{T}}
3435

35-
function link(d::TransformDistribution, x::Real)
36+
link(d::TransformDistribution, x::Real) = begin
3637
a, b = minimum(d), maximum(d)
3738
lowerbounded, upperbounded = isfinite(a), isfinite(b)
3839
if lowerbounded && upperbounded
@@ -46,7 +47,7 @@ function link(d::TransformDistribution, x::Real)
4647
end
4748
end
4849

49-
function invlink(d::TransformDistribution, x::Real)
50+
invlink(d::TransformDistribution, x::Real) = begin
5051
a, b = minimum(d), maximum(d)
5152
lowerbounded, upperbounded = isfinite(a), isfinite(b)
5253
if lowerbounded && upperbounded
@@ -76,12 +77,13 @@ Distributions.logpdf(d::TransformDistribution, x::Real, transform::Bool) = begin
7677
lp
7778
end
7879

79-
80-
#################### RealDistribution ####################
80+
###############
81+
# -∞ < x < -∞ #
82+
###############
8183

8284
typealias RealDistribution
83-
Union{Cauchy, Gumbel, Laplace, Logistic, NoncentralT, Normal,
84-
NormalCanon, TDist}
85+
Union{Cauchy, Gumbel, Laplace, Logistic,
86+
NoncentralT, Normal, NormalCanon, TDist}
8587

8688
link(d::RealDistribution, x::Real) = x
8789

@@ -90,7 +92,9 @@ invlink(d::RealDistribution, x::Real) = x
9092
Distributions.logpdf(d::RealDistribution, x::Real, transform::Bool) = logpdf(d, x)
9193

9294

93-
#################### PositiveDistribution ####################
95+
#########
96+
# 0 < x #
97+
#########
9498

9599
typealias PositiveDistribution
96100
Union{BetaPrime, Chi, Chisq, Erlang, Exponential, FDist, Frechet,
@@ -107,7 +111,9 @@ Distributions.logpdf(d::PositiveDistribution, x::Real, transform::Bool) = begin
107111
end
108112

109113

110-
#################### UnitDistribution ####################
114+
#############
115+
# 0 < x < 1 #
116+
#############
111117

112118
typealias UnitDistribution
113119
Union{Beta, KSOneSided, NoncentralBeta}
@@ -118,46 +124,45 @@ invlink(d::UnitDistribution, x::Real) = invlogit(x)
118124

119125
Distributions.logpdf(d::UnitDistribution, x::Real, transform::Bool) = begin
120126
lp = logpdf(d, x)
121-
transform ? lp + log(x * (1.0 - x)) : lp
127+
transform ? lp + log(x * (one(x) - x)) : lp
122128
end
123129

124-
################### SimplexDistribution ###################
130+
###########
131+
# ∑xᵢ = 1 #
132+
###########
125133

126134
typealias SimplexDistribution Union{Dirichlet}
127135

128-
function link(d::SimplexDistribution, x::Vector)
136+
link{T}(d::SimplexDistribution, x::Vector{T}) = begin
129137
K = length(x)
130-
T = typeof(x[1])
131138
z = Vector{T}(K-1)
132139
for k in 1:K-1
133-
z[k] = x[k] / (1 - sum(x[1:k-1]))
140+
z[k] = x[k] / (one(T) - sum(x[1:k-1]))
134141
end
135-
y = [logit(z[k]) - log(1 / (K-k)) for k in 1:K-1]
136-
push!(y, T(0))
142+
y = [logit(z[k]) - log(one(T) / (K-k)) for k in 1:K-1]
143+
push!(y, zero(T))
137144
end
138145

139-
function invlink(d::SimplexDistribution, y::Vector)
146+
invlink{T}(d::SimplexDistribution, y::Vector{T}) = begin
140147
K = length(y)
141-
T = typeof(y[1])
142-
z = [invlogit(y[k] + log(1 / (K - k))) for k in 1:K-1]
148+
z = [invlogit(y[k] + log(one(T) / (K - k))) for k in 1:K-1]
143149
x = Vector{T}(K)
144150
for k in 1:K-1
145-
x[k] = (1 - sum(x[1:k-1])) * z[k]
151+
x[k] = (one(T) - sum(x[1:k-1])) * z[k]
146152
end
147-
x[K] = 1 - sum(x[1:K-1])
153+
x[K] = one(T) - sum(x[1:K-1])
148154
x
149155
end
150156

151-
Distributions.logpdf(d::SimplexDistribution, x::Vector, transform::Bool) = begin
157+
Distributions.logpdf{T}(d::SimplexDistribution, x::Vector{T}, transform::Bool) = begin
152158
lp = logpdf(d, x)
153159
if transform
154160
K = length(x)
155-
T = typeof(x[1])
156161
z = Vector{T}(K-1)
157162
for k in 1:K-1
158-
z[k] = x[k] / (1 - sum(x[1:k-1]))
163+
z[k] = x[k] / (one(T) - sum(x[1:k-1]))
159164
end
160-
lp += sum([log(z[k]) + log(1 - z[k]) + log(1 - sum(x[1:k-1])) for k in 1:K-1])
165+
lp += sum([log(z[k]) + log(one(T) - z[k]) + log(one(T) - sum(x[1:k-1])) for k in 1:K-1])
161166
end
162167
lp
163168
end
@@ -166,50 +171,54 @@ Distributions.logpdf(d::Categorical, x::Int) = begin
166171
d.p[x] > 0.0 && insupport(d, x) ? log(d.p[x]) : eltype(d.p)(-Inf)
167172
end
168173

169-
############### PDMatDistribution ##############
174+
#####################
175+
# Positive definite #
176+
#####################
170177

171178
typealias PDMatDistribution Union{InverseWishart, Wishart}
172179

173-
function link{T}(d::PDMatDistribution, x::Array{T,2})
180+
link{T}(d::PDMatDistribution, x::Array{T,2}) = begin
174181
z = chol(x)'
175182
dim = size(z)
176183
for m in 1:dim[1]
177184
z[m, m] = log(z[m, m])
178185
end
179186
for m in 1:dim[1], n in m+1:dim[2]
180-
z[m, n] = 0
187+
z[m, n] = zero(T)
181188
end
182189
Array{T,2}(z)
183190
end
184191

185-
function invlink{T}(d::PDMatDistribution, z::Array{T,2})
192+
invlink{T}(d::PDMatDistribution, z::Array{T,2}) = begin
186193
dim = size(z)
187194
for m in 1:dim[1]
188195
z[m, m] = exp(z[m, m])
189196
end
190197
for m in 1:dim[1], n in m+1:dim[2]
191-
z[m, n] = 0
198+
z[m, n] = zero(T)
192199
end
193200
Array{T,2}(z * z')
194201
end
195202

196-
Distributions.logpdf(d::PDMatDistribution, x::Array, transform::Bool) = begin
203+
Distributions.logpdf{T}(d::PDMatDistribution, x::Array{T,2}, transform::Bool) = begin
197204
lp = logpdf(d, x)
198205
if transform && isfinite(lp)
199206
U = chol(x)
200207
n = dim(d)
201208
for i in 1:n
202-
lp += (n - i + 2) * log(U[i, i])
209+
lp += (n - i + one(T) + one(T)) * log(U[i, i])
203210
end
204-
lp += n * log(2)
211+
lp += n * log(one(T) + one(T))
205212
end
206213
lp
207214
end
208215

209-
################## Callback functions ##################
216+
#############
217+
# Callbacks #
218+
#############
210219

211-
link(d::Distribution, x) = x
220+
link(d::Distribution, x::Any) = x
212221

213-
invlink(d::Distribution, x) = x
222+
invlink(d::Distribution, x::Any) = x
214223

215-
Distributions.logpdf(d::Distribution, x, transform::Bool) = logpdf(d, x)
224+
Distributions.logpdf(d::Distribution, x::Any, transform::Bool) = logpdf(d, x)

0 commit comments

Comments
 (0)