Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] fit2: fast version of fitting algorithm #14

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions src/BetaRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,31 @@ function StatsAPI.score(b::BetaRegressionModel)
return ∂θ
end


function score2!(b::BetaRegressionModel, ∂θ)
X = modelmatrix(b)
y = response(b)
η = linearpredictor(b)
link = Link(b)
ϕ, dϕ, _ = precisioninverselink(b)
ψϕ = digamma(ϕ)
∂θ .= zero(eltype(X))
Tr = copy(η)
@inbounds for i in eachindex(y, η)
yᵢ = y[i]
μᵢ, omμᵢ, dμdη = inverselink(link, η[i])
ψp = digamma(ϕ * μᵢ)
ψq = digamma(ϕ * omμᵢ)
Δ = logit(yᵢ) - ψp + ψq # logit(yᵢ) - 𝔼(logit(yᵢ))
z = log1p(-yᵢ) - ψq + ψϕ # log(1 - yᵢ) - 𝔼(log(1 - yᵢ))
∂θ[end] += fma(μᵢ, Δ, z)
Tr[i] = ϕ * Δ * dμdη
end
mul!(view(∂θ, 1:size(X, 2)), X', Tr)
∂θ[end] *= dϕ
return ∂θ
end

# Square root of the diagonal of the weight matrix, W for expected information (pg 7),
# Q for observed information (pg 10). `p = μ * ϕ` and `q = (1 - μ) * ϕ` are the beta
# distribution parameters in the typical parameterization, `ψ′_` is `trigamma(_)`.
Expand Down Expand Up @@ -537,6 +562,44 @@ function 🐟(b::BetaRegressionModel, expected::Bool, inverse::Bool)
return Symmetric(K)
end


function 🐟!(b::BetaRegressionModel, K::AbstractMatrix, w, expected::Bool)
X = modelmatrix(b)
k = length(params(b))
η = linearpredictor(b)
y = response(b)
link = Link(b)
ϕ, dϕ, _ = precisioninverselink(b)
ψ′ϕ = trigamma(ϕ)
Tc = similar(η)
Tc .= dϕ
γ = zero(ϕ)

for i in eachindex(y, η, w)
ηᵢ = η[i]
μᵢ, omμᵢ, dμdη = inverselink(link, ηᵢ)
p = μᵢ * ϕ
q = omμᵢ * ϕ
ψ′p = trigamma(p)
ψ′q = trigamma(q)
w[i] = weightdiag(link, p, q, ψ′p, ψ′q, ϕ, y[i], ηᵢ, dμdη, expected)
Tc[i] *= (ψ′p * p - ψ′q * q) * dμdη
γ += ψ′p * μᵢ^2 + ψ′q * omμᵢ^2 - ψ′ϕ
end
γ *= dϕ^2
Xᵀ = copy(adjoint(X))
# update Kβϕ
view(K, 1:(k - 1), k) .= Xᵀ * Tc

Xᵀ .*= w'
# update Kββ
view(K, 1:(k - 1), 1:(k - 1)) .= Symmetric(syrk('U', 'N', ϕ, Xᵀ))

# update Kϕϕ
K[k, k] = γ
return K
end

"""
informationmatrix(model::BetaRegressionModel; expected=true)

Expand Down Expand Up @@ -599,6 +662,29 @@ function StatsAPI.fit!(b::BetaRegressionModel{T}; maxiter=100, atol=sqrt(eps(T))
throw(ConvergenceException(maxiter))
end

function fit2!(b::BetaRegressionModel{T}; maxiter=100, atol=sqrt(eps(T)),
rtol=Base.rtoldefault(T)) where {T}
initialize!(b)
θ = params(b)
z = zero(θ)
U = zero(θ)
p = length(θ)
K = zeros(T, p, p)
wwrk = similar(b.y)
# η = linearpredictor(b)
for iter in 1:maxiter
score2!(b, U)
checkfinite(U, iter)
isapprox(U, z; atol, rtol) && return b # converged!
🐟!(b, K, wwrk, true)
checkfinite(K, iter)
θ .+= ldiv!(U, cholesky!(Symmetric(K)), U)
θ[end] = max(θ[end], eps(eltype(θ)))
linearpredictor!(b)
end
throw(ConvergenceException(maxiter))
end

"""
fit(BetaRegressionModel, formula, data, link=LogitLink(), precisionlink=IdentityLink();
kwargs...)
Expand Down