Skip to content

Commit 50f8a55

Browse files
committed
Merge branch 'stoch_grad_hmc' of sivapvarma/AdvancedHMC.jl into stoch_grad_hmc
2 parents 0b60062 + 944b9f5 commit 50f8a55

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed

src/trajectory.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,93 @@ end
747747

748748

749749

750+
###
751+
### Stochastic Gradient Langevin Dynamics sampler.
752+
###
753+
"""
754+
Stochastic Gradient Langevin Dynamics with fixed number of steps.
755+
"""
756+
mutable struct SGLD{
757+
I<:AbstractIntegrator,
758+
F<:AbstractFloat
759+
} <: AbstractTrajectory{I}
760+
integrator :: I
761+
n_steps :: Int # number of samples
762+
ϵ :: F # constant scale factor of the learning rate
763+
i :: Int # iteration counter
764+
γ :: F # scaling constant
765+
end
766+
767+
function transition(
768+
rng::AbstractRNG,
769+
τ::SGLD,
770+
h::Hamiltonian,
771+
z::PhasePoint
772+
) where {T<:Real}
773+
# z′ = step(rng, τ.integrator, h, z, τ.n_steps)
774+
DEBUG && @debug "compute current step size..."
775+
# γ = .35
776+
τ.i += 1
777+
ϵ_t = τ.ϵ / τ.i ^ τ.γ # NOTE: Choose γ=.55 in paper
778+
779+
DEBUG && @debug "recording old variables..."
780+
θ = z.θ
781+
grad = -z.ℓπ.gradient
782+
783+
DEBUG && @debug "update latent variables..."
784+
θ .+= ϵ_t .* grad ./ 2 .+ rand.(Normal.(zeros(length(θ)), sqrt(ϵ_t)))
785+
786+
# no M-H step
787+
z = PhasePoint(h, θ, -z.r)
788+
stat = (
789+
step_size=τ.integrator.ϵ,
790+
n_steps=τ.n_steps,
791+
log_density=z.ℓπ.value,
792+
hamiltonian_energy=energy(z),
793+
)
794+
return Transition(z, stat)
795+
end
796+
797+
##
798+
## Stochastic Gradient Hamilton Samplers
799+
##
800+
801+
###
802+
### Stochastic Gradient Hamiltonian Monte Carlo sampler.
803+
###
804+
"""
805+
Stochastic Gradient HMC with fixed number of steps.
806+
"""
807+
struct SGHMC{
808+
I<:AbstractIntegrator,
809+
F<:AbstractFloat
810+
} <: AbstractTrajectory{I}
811+
integrator :: I
812+
n_steps :: Int # number of samples
813+
η :: F # learning rate
814+
α :: F # momentum decay
815+
end
816+
817+
function transition(
818+
rng::AbstractRNG,
819+
τ::SGHMC,
820+
h::Hamiltonian,
821+
z::PhasePoint
822+
) where {T<:Real}
823+
z′ = step(rng, τ.integrator, h, z, τ.n_steps)
824+
# no M-H step
825+
z = PhasePoint(z′.θ, z′.r, z′.ℓπ, z′.ℓκ)
826+
stat = (
827+
step_size=τ.integrator.ϵ,
828+
n_steps=τ.n_steps,
829+
log_density=z.ℓπ.value,
830+
hamiltonian_energy=energy(z),
831+
)
832+
return Transition(z, stat)
833+
end
834+
835+
836+
750837
###
751838
### Stochastic Gradient Langevin Dynamics sampler.
752839
###

0 commit comments

Comments
 (0)