Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ function run_benchmarks!(𝓂::ℳ, SUITE::BenchmarkGroup)

SUITE[𝓂.model_name]["qme"] = BenchmarkGroup()

sol, qme_sol, solved = calculate_first_order_solution(∇₁; T = 𝓂.timings, opts = merge_calculation_options(quadratic_matrix_equation_algorithm = :schur))
sol, qme_sol, solved = calculate_first_order_solution(∇₁; T = 𝓂.timings, opts = merge_calculation_options(quadratic_matrix_equation_algorithm = :schur), 𝒬ℂ = 𝓂.caches.qme_caches)

clear_solution_caches!(𝓂, :first_order)

SUITE[𝓂.model_name]["qme"]["schur"] = @benchmarkable calculate_first_order_solution($∇₁; T = $𝓂.timings, opts = merge_calculation_options(quadratic_matrix_equation_algorithm = :schur)) setup = clear_solution_caches!($𝓂, :first_order)
SUITE[𝓂.model_name]["qme"]["schur"] = @benchmarkable calculate_first_order_solution($∇₁; T = $𝓂.timings, opts = merge_calculation_options(quadratic_matrix_equation_algorithm = :schur), 𝒬ℂ = $𝓂.caches.qme_caches) setup = clear_solution_caches!($𝓂, :first_order)

SUITE[𝓂.model_name]["qme"]["doubling"] = @benchmarkable calculate_first_order_solution($∇₁; T = $𝓂.timings, opts = merge_calculation_options(quadratic_matrix_equation_algorithm = :doubling)) setup = clear_solution_caches!($𝓂, :first_order)
SUITE[𝓂.model_name]["qme"]["doubling"] = @benchmarkable calculate_first_order_solution($∇₁; T = $𝓂.timings, opts = merge_calculation_options(quadratic_matrix_equation_algorithm = :doubling), 𝒬ℂ = $𝓂.caches.qme_caches) setup = clear_solution_caches!($𝓂, :first_order)


A = @views sol[:, 1:𝓂.timings.nPast_not_future_and_mixed] * ℒ.diagm(ones(𝓂.timings.nVars))[𝓂.timings.past_not_future_and_mixed_idx,:]
Expand Down
42 changes: 24 additions & 18 deletions src/MacroModelling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ function clear_solution_caches!(𝓂::ℳ, algorithm::Symbol)
𝓂.solution.perturbation.qme_solution = zeros(0,0)
𝓂.solution.perturbation.second_order_solution = spzeros(0,0)
𝓂.solution.perturbation.third_order_solution = spzeros(0,0)
𝓂.caches.qme_caches = QME_caches()

return nothing
end
Expand Down Expand Up @@ -5042,10 +5043,11 @@ function calculate_second_order_stochastic_steady_state(parameters::Vector{M},

# @timeit_debug timer "Calculate first order solution" begin

𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁;
T = 𝓂.timings,
𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁;
T = 𝓂.timings,
opts = opts,
initial_guess = 𝓂.solution.perturbation.qme_solution)
initial_guess = 𝓂.solution.perturbation.qme_solution,
𝒬ℂ = 𝓂.caches.qme_caches)

if solved 𝓂.solution.perturbation.qme_solution = qme_sol end

Expand Down Expand Up @@ -5369,10 +5371,11 @@ function calculate_third_order_stochastic_steady_state( parameters::Vector{M},

∇₁ = calculate_jacobian(parameters, SS_and_pars, 𝓂)# |> Matrix

𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁;
T = 𝓂.timings,
𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁;
T = 𝓂.timings,
opts = opts,
initial_guess = 𝓂.solution.perturbation.qme_solution)
initial_guess = 𝓂.solution.perturbation.qme_solution,
𝒬ℂ = 𝓂.caches.qme_caches)

if solved 𝓂.solution.perturbation.qme_solution = qme_sol end

Expand Down Expand Up @@ -5754,10 +5757,11 @@ function solve!(𝓂::ℳ;

# @timeit_debug timer "Calculate first order solution" begin

S₁, qme_sol, solved = calculate_first_order_solution(∇₁;
T = 𝓂.timings,
S₁, qme_sol, solved = calculate_first_order_solution(∇₁;
T = 𝓂.timings,
opts = opts,
initial_guess = 𝓂.solution.perturbation.qme_solution)
initial_guess = 𝓂.solution.perturbation.qme_solution,
𝒬ℂ = 𝓂.caches.qme_caches)

if solved 𝓂.solution.perturbation.qme_solution = qme_sol end

Expand All @@ -5776,10 +5780,11 @@ function solve!(𝓂::ℳ;

∇̂₁ = calculate_jacobian(𝓂.parameter_values, SS_and_pars, 𝓂)# |> Matrix

Ŝ₁, qme_sol, solved = calculate_first_order_solution(∇̂₁;
T = 𝓂.timings,
opts = opts,
initial_guess = 𝓂.solution.perturbation.qme_solution)
Ŝ₁, qme_sol, solved = calculate_first_order_solution(∇̂₁;
T = 𝓂.timings,
opts = opts,
initial_guess = 𝓂.solution.perturbation.qme_solution,
𝒬ℂ = 𝓂.caches.qme_caches)

if solved 𝓂.solution.perturbation.qme_solution = qme_sol end

Expand Down Expand Up @@ -8358,11 +8363,12 @@ function get_relevant_steady_state_and_state_update(::Val{:first_order},

∇₁ = calculate_jacobian(parameter_values, SS_and_pars, 𝓂) # , timer = timer)# |> Matrix

𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁;
T = TT,
# timer = timer,
initial_guess = 𝓂.solution.perturbation.qme_solution,
opts = opts)
𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁;
T = TT,
# timer = timer,
initial_guess = 𝓂.solution.perturbation.qme_solution,
opts = opts,
𝒬ℂ = 𝓂.caches.qme_caches)

if solved 𝓂.solution.perturbation.qme_solution = qme_sol end

Expand Down
112 changes: 72 additions & 40 deletions src/algorithms/quadratic_matrix_equation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@

@stable default_mode = "disable" begin

function solve_quadratic_matrix_equation(A::AbstractMatrix{R},
B::AbstractMatrix{R},
C::AbstractMatrix{R},
T::timings;
function solve_quadratic_matrix_equation(A::AbstractMatrix{R},
B::AbstractMatrix{R},
C::AbstractMatrix{R},
T::timings;
initial_guess::AbstractMatrix{R} = zeros(0,0),
quadratic_matrix_equation_algorithm::Symbol = :schur,
tol::AbstractFloat = 1e-14,
acceptance_tol::AbstractFloat = 1e-8,
verbose::Bool = false) where R <: Real
verbose::Bool = false,
𝒬ℂ::qme_caches = QME_caches()) where R <: Real

if length(initial_guess) > 0
X = initial_guess
Expand All @@ -38,13 +39,23 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R},
end
end

sol, iterations, reached_tol = solve_quadratic_matrix_equation(A, B, C,
Val(quadratic_matrix_equation_algorithm),
T;
if quadratic_matrix_equation_algorithm == :doubling
sol, iterations, reached_tol = solve_quadratic_matrix_equation(A, B, C,
Val(:doubling),
T;
initial_guess = initial_guess,
tol = tol,
verbose = verbose,
𝒬ℂ = 𝒬ℂ)
else
sol, iterations, reached_tol = solve_quadratic_matrix_equation(A, B, C,
Val(quadratic_matrix_equation_algorithm),
T;
initial_guess = initial_guess,
tol = tol,
# timer = timer,
verbose = verbose)
end

if verbose println("Quadratic matrix equation solver: $quadratic_matrix_equation_algorithm - converged: $(reached_tol < acceptance_tol) in $iterations iterations to tolerance: $reached_tol") end

Expand All @@ -60,13 +71,14 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R},

if verbose println("Quadratic matrix equation solver: schur - converged: $(reached_tol < acceptance_tol) in $iterations iterations to tolerance: $reached_tol") end
else quadratic_matrix_equation_algorithm ≠ :doubling
sol, iterations, reached_tol = solve_quadratic_matrix_equation(A, B, C,
Val(:doubling),
T;
sol, iterations, reached_tol = solve_quadratic_matrix_equation(A, B, C,
Val(:doubling),
T;
initial_guess = initial_guess,
tol = tol,
# timer = timer,
verbose = verbose)
verbose = verbose,
𝒬ℂ = 𝒬ℂ)

if verbose println("Quadratic matrix equation solver: doubling - converged: $(reached_tol < acceptance_tol) in $iterations iterations to tolerance: $reached_tol") end
end
Expand Down Expand Up @@ -206,16 +218,17 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R},
end


function solve_quadratic_matrix_equation(A::AbstractMatrix{R},
B::AbstractMatrix{R},
C::AbstractMatrix{R},
::Val{:doubling},
T::timings;
function solve_quadratic_matrix_equation(A::AbstractMatrix{R},
B::AbstractMatrix{R},
C::AbstractMatrix{R},
::Val{:doubling},
T::timings;
initial_guess::AbstractMatrix{R} = zeros(0,0),
tol::AbstractFloat = 1e-14,
# timer::TimerOutput = TimerOutput(),
verbose::Bool = false,
max_iter::Int = 100)::Tuple{Matrix{R}, Int64, R} where R <: AbstractFloat
max_iter::Int = 100,
𝒬ℂ::qme_caches = QME_caches())::Tuple{Matrix{R}, Int64, R} where R <: AbstractFloat
# Johannes Huber, Alexander Meyer-Gohde, Johanna Saecker (2024). Solving Linear DSGE Models with Structure Preserving Doubling Methods.
# https://www.imfs-frankfurt.de/forschung/imfs-working-papers/details.html?tx_mmpublications_publicationsdetail%5Bcontroller%5D=Publication&tx_mmpublications_publicationsdetail%5Bpublication%5D=461&cHash=f53244e0345a27419a9d40a3af98c02f
# https://arxiv.org/abs/2212.09491
Expand Down Expand Up @@ -282,19 +295,26 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R},
# end # timeit_debug
# @timeit_debug timer "Invert EI" begin

fEI = ℒ.lu!(temp1, check = false)

if !ℒ.issuccess(fEI)
return A, iter, 1.0
alg = issparse(temp1) ? 𝒮.UMFPACKFactorization() : 𝒮.LUFactorization()
if !(typeof(𝒬ℂ.EI_cache.alg) === typeof(alg)) ||
size(𝒬ℂ.EI_cache.A) != size(temp1) ||
length(𝒬ℂ.EI_cache.b) != size(E,1)
prob = 𝒮.LinearProblem(copy(temp1), zeros(R, size(temp1,1)), alg)
𝒬ℂ.EI_cache = 𝒮.init(prob, alg)
else
copy!(𝒬ℂ.EI_cache.A, temp1)
𝒬ℂ.EI_cache.isfresh = true
end

for j in 1:size(E,2)
copyto!(𝒬ℂ.EI_cache.b, view(E,:,j))
𝒮.solve!(𝒬ℂ.EI_cache)
temp3[:,j] .= 𝒬ℂ.EI_cache.u
end
ℒ.mul!(E_new, E, temp3)
# end # timeit_debug
# @timeit_debug timer "Compute E" begin

# Compute E = E * EI * E
ℒ.ldiv!(temp3, fEI, E)
ℒ.mul!(E_new, E, temp3)
# E_new = E / fEI * E
# E_new = E / EI * E

# end # timeit_debug
# @timeit_debug timer "Compute FI" begin
Expand All @@ -309,26 +329,34 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R},
# end # timeit_debug
# @timeit_debug timer "Invert FI" begin

fFI = ℒ.lu!(temp2, check = false)

if !ℒ.issuccess(fFI)
return A, iter, 1.0
fFI_cache_alg = issparse(temp2) ? 𝒮.UMFPACKFactorization() : 𝒮.LUFactorization()
if !(typeof(𝒬ℂ.FI_cache.alg) === typeof(fFI_cache_alg)) ||
size(𝒬ℂ.FI_cache.A) != size(temp2) ||
length(𝒬ℂ.FI_cache.b) != size(F,1)
prob = 𝒮.LinearProblem(copy(temp2), zeros(R, size(temp2,1)), fFI_cache_alg)
𝒬ℂ.FI_cache = 𝒮.init(prob, fFI_cache_alg)
else
copy!(𝒬ℂ.FI_cache.A, temp2)
𝒬ℂ.FI_cache.isfresh = true
end
for j in 1:size(F,2)
copyto!(𝒬ℂ.FI_cache.b, view(F,:,j))
𝒮.solve!(𝒬ℂ.FI_cache)
temp3[:,j] .= 𝒬ℂ.FI_cache.u
end

# end # timeit_debug
# @timeit_debug timer "Compute F" begin

# Compute F = F * FI * F
ℒ.ldiv!(temp3, fFI, F)
ℒ.mul!(F_new, F, temp3)
# F_new = F / fFI * F
# F_new = F / FI * F

# end # timeit_debug
# @timeit_debug timer "Compute X_new" begin

# Compute X_new = X + F * FI * X * E
ℒ.mul!(temp3, X, E)
ℒ.ldiv!(fFI, temp3)
for j in 1:size(temp3,2)
copyto!(𝒬ℂ.FI_cache.b, view(temp3,:,j))
𝒮.solve!(𝒬ℂ.FI_cache)
temp3[:,j] .= 𝒬ℂ.FI_cache.u
end
ℒ.mul!(X_new, F, temp3)
# X_new = F / fFI * X * E
if i > 5 || guess_provided
Expand All @@ -343,7 +371,11 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R},

# Compute Y_new = Y + E * EI * Y * F
ℒ.mul!(X, Y, F) # use X as temporary storage
ℒ.ldiv!(fEI, X)
for j in 1:size(X,2)
copyto!(𝒬ℂ.EI_cache.b, view(X,:,j))
𝒮.solve!(𝒬ℂ.EI_cache)
X[:,j] .= 𝒬ℂ.EI_cache.u
end
ℒ.mul!(Y_new, E, X)
# Y_new = E / fEI * Y * F
if i > 5 || guess_provided
Expand Down
9 changes: 5 additions & 4 deletions src/filter/inversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3457,10 +3457,11 @@ function filter_data_with_model(𝓂::ℳ,

∇₁ = calculate_jacobian(𝓂.parameter_values, SS_and_pars, 𝓂)# |> Matrix

𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁;
T = T,
initial_guess = 𝓂.solution.perturbation.qme_solution,
opts = opts)
𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁;
T = T,
initial_guess = 𝓂.solution.perturbation.qme_solution,
opts = opts,
𝒬ℂ = 𝓂.caches.qme_caches)

if solved 𝓂.solution.perturbation.qme_solution = qme_sol end

Expand Down
2 changes: 1 addition & 1 deletion src/filter/kalman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ function filter_and_smooth(𝓂::ℳ,

∇₁ = calculate_jacobian(parameters, SS_and_pars, 𝓂)# |> Matrix

sol, qme_sol, solved = calculate_first_order_solution(∇₁; T = 𝓂.timings, opts = opts)
sol, qme_sol, solved = calculate_first_order_solution(∇₁; T = 𝓂.timings, opts = opts, 𝒬ℂ = 𝓂.caches.qme_caches)

if solved 𝓂.solution.perturbation.qme_solution = qme_sol end

Expand Down
Loading
Loading