diff --git a/src/trajectory.jl b/src/trajectory.jl
index 066fcc8d..2e73a192 100644
--- a/src/trajectory.jl
+++ b/src/trajectory.jl
@@ -221,5 +221,5 @@ function accept_phasepoint!(z::T, z′::T, is_accept) where {T<:PhasePoint{<:Abs
     return z′
 end
 
-### Use end-point from trajecory as proposal
+### Use end-point from trajectory as proposal
 samplecand(rng, τ::StaticTrajectory{EndPointTS}, h, z) = step(τ.integrator, h, z, τ.n_steps)
@@ -706,3 +706,102 @@ function mh_accept_ratio(
     accept = rand(rng, T, length(Horiginal)) .< α
     return accept, α
 end
+
+##
+## Stochastic Gradient Hamiltonian Samplers
+##
+
+###
+### Stochastic Gradient Hamiltonian Monte Carlo sampler.
+###
+"""
+Stochastic Gradient HMC with fixed number of steps.
+"""
+struct SGHMC{
+    I<:AbstractIntegrator,
+    F<:AbstractFloat
+} <: AbstractTrajectory{I}
+    integrator :: I
+    n_steps :: Int # number of samples
+    batch_size :: Int # number of data points in the minibatch gradient estimate
+    η :: F # learning rate
+    α :: F # momentum decay
+end
+
+function transition(
+    rng::AbstractRNG,
+    τ::SGHMC,
+    h::Hamiltonian,
+    z::PhasePoint
+)
+    m, η, α, D = τ.n_steps, τ.η, τ.α, τ.batch_size
+
+    @unpack θ, r = z
+
+    for i in 1:m
+        # TODO: compute a proper stochastic gradient from a size-D minibatch;
+        # `gradient(h, D)` is a placeholder for that estimator.
+        stoch_grad = gradient(h, D)
+
+        # Update position using the current momentum.
+        θ .+= r
+
+        # Update momentum: friction term, stochastic gradient, and injected
+        # Gaussian noise with variance 2ηα (per-coordinate, i.i.d.).
+        r .= (1 - α) .* r .+ η .* stoch_grad .+ rand.(Normal.(zeros(length(θ)), sqrt(2 * η * α)))
+    end
+
+    # No Metropolis-Hastings correction; re-wrap the updated position/momentum
+    # so cached energies are recomputed for the new coordinates.
+    z = PhasePoint(h, θ, r)
+    stat = (
+        step_size=τ.integrator.ϵ,
+        n_steps=τ.n_steps,
+        log_density=z.ℓπ.value,
+        hamiltonian_energy=energy(z),
+    )
+    return Transition(z, stat)
+end
+
+###
+### Stochastic Gradient Langevin Dynamics sampler.
+###
+"""
+Stochastic Gradient Langevin Dynamics with fixed number of steps.
+"""
+mutable struct SGLD{
+    I<:AbstractIntegrator,
+    F<:AbstractFloat
+} <: AbstractTrajectory{I}
+    integrator :: I
+    n_steps :: Int # number of samples
+    ϵ :: F # constant scale factor of the learning rate
+    i :: Int # iteration counter
+    γ :: F # step-size decay exponent
+end
+
+function transition(
+    rng::AbstractRNG,
+    τ::SGLD,
+    h::Hamiltonian,
+    z::PhasePoint
+)
+    DEBUG && @debug "compute current step size..."
+    τ.i += 1
+    ϵ_t = τ.ϵ / τ.i ^ τ.γ # NOTE: the paper chooses γ = 0.55
+
+    DEBUG && @debug "recording old variables..."
+    θ = z.θ
+    # TODO: replace the full gradient with a stochastic minibatch gradient.
+    grad = -z.ℓπ.gradient
+
+    DEBUG && @debug "update latent variables..."
+    θ .+= ϵ_t .* grad ./ 2 .+ rand.(Normal.(zeros(length(θ)), sqrt(ϵ_t)))
+
+    # No Metropolis-Hastings correction; re-wrap with flipped momentum.
+    z = PhasePoint(h, θ, -z.r)
+    stat = (
+        step_size=τ.integrator.ϵ,
+        n_steps=τ.n_steps,
+        log_density=z.ℓπ.value,
+        hamiltonian_energy=energy(z),
+    )
+    return Transition(z, stat)
+end