diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 78091152b..352fb258d 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -43,13 +43,13 @@ function run_benchmarks!(𝓂::â„ģ, SUITE::BenchmarkGroup) SUITE[𝓂.model_name]["qme"] = BenchmarkGroup() - sol, qme_sol, solved = calculate_first_order_solution(∇₁; T = 𝓂.timings, opts = merge_calculation_options(quadratic_matrix_equation_algorithm = :schur)) + sol, qme_sol, solved = calculate_first_order_solution(∇₁; T = 𝓂.timings, opts = merge_calculation_options(quadratic_matrix_equation_algorithm = :schur), 𝒎ℂ = 𝓂.caches.qme_caches) clear_solution_caches!(𝓂, :first_order) - SUITE[𝓂.model_name]["qme"]["schur"] = @benchmarkable calculate_first_order_solution($∇₁; T = $𝓂.timings, opts = merge_calculation_options(quadratic_matrix_equation_algorithm = :schur)) setup = clear_solution_caches!($𝓂, :first_order) + SUITE[𝓂.model_name]["qme"]["schur"] = @benchmarkable calculate_first_order_solution($∇₁; T = $𝓂.timings, opts = merge_calculation_options(quadratic_matrix_equation_algorithm = :schur), 𝒎ℂ = $𝓂.caches.qme_caches) setup = clear_solution_caches!($𝓂, :first_order) - SUITE[𝓂.model_name]["qme"]["doubling"] = @benchmarkable calculate_first_order_solution($∇₁; T = $𝓂.timings, opts = merge_calculation_options(quadratic_matrix_equation_algorithm = :doubling)) setup = clear_solution_caches!($𝓂, :first_order) + SUITE[𝓂.model_name]["qme"]["doubling"] = @benchmarkable calculate_first_order_solution($∇₁; T = $𝓂.timings, opts = merge_calculation_options(quadratic_matrix_equation_algorithm = :doubling), 𝒎ℂ = $𝓂.caches.qme_caches) setup = clear_solution_caches!($𝓂, :first_order) A = @views sol[:, 1:𝓂.timings.nPast_not_future_and_mixed] * ℒ.diagm(ones(𝓂.timings.nVars))[𝓂.timings.past_not_future_and_mixed_idx,:] diff --git a/src/MacroModelling.jl b/src/MacroModelling.jl index 7d542fef1..e55f37e2e 100644 --- a/src/MacroModelling.jl +++ b/src/MacroModelling.jl @@ -692,6 +692,7 @@ function clear_solution_caches!(𝓂::â„ģ, algorithm::Symbol) 𝓂.solution.perturbation.qme_solution = zeros(0,0) 𝓂.solution.perturbation.second_order_solution = spzeros(0,0) 𝓂.solution.perturbation.third_order_solution = spzeros(0,0) + 𝓂.caches.qme_caches = QME_caches() return nothing end @@ -5042,10 +5043,11 @@ function calculate_second_order_stochastic_steady_state(parameters::Vector{M}, # @timeit_debug timer "Calculate first order solution" begin - 𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁; - T = 𝓂.timings, + 𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁; + T = 𝓂.timings, opts = opts, - initial_guess = 𝓂.solution.perturbation.qme_solution) + initial_guess = 𝓂.solution.perturbation.qme_solution, + 𝒎ℂ = 𝓂.caches.qme_caches) if solved 𝓂.solution.perturbation.qme_solution = qme_sol end @@ -5369,10 +5371,11 @@ function calculate_third_order_stochastic_steady_state( parameters::Vector{M}, ∇₁ = calculate_jacobian(parameters, SS_and_pars, 𝓂)# |> Matrix - 𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁; - T = 𝓂.timings, + 𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁; + T = 𝓂.timings, opts = opts, - initial_guess = 𝓂.solution.perturbation.qme_solution) + initial_guess = 𝓂.solution.perturbation.qme_solution, + 𝒎ℂ = 𝓂.caches.qme_caches) if solved 𝓂.solution.perturbation.qme_solution = qme_sol end @@ -5754,10 +5757,11 @@ function solve!(𝓂::â„ģ; # @timeit_debug timer "Calculate first order solution" begin - S₁, qme_sol, solved = calculate_first_order_solution(∇₁; - T = 𝓂.timings, + S₁, qme_sol, solved = calculate_first_order_solution(∇₁; + T = 𝓂.timings, opts = opts, - initial_guess = 𝓂.solution.perturbation.qme_solution) + initial_guess = 𝓂.solution.perturbation.qme_solution, + 𝒎ℂ = 𝓂.caches.qme_caches) if solved 𝓂.solution.perturbation.qme_solution = qme_sol end @@ -5776,10 +5780,11 @@ function solve!(𝓂::â„ģ; âˆ‡Ė‚â‚ = calculate_jacobian(𝓂.parameter_values, SS_and_pars, 𝓂)# |> Matrix - SĖ‚â‚, qme_sol, solved = calculate_first_order_solution(âˆ‡Ė‚â‚; - T = 𝓂.timings, - opts = opts, - initial_guess = 𝓂.solution.perturbation.qme_solution) + SĖ‚â‚, qme_sol, solved = calculate_first_order_solution(âˆ‡Ė‚â‚; + T = 𝓂.timings, + opts = opts, + initial_guess = 𝓂.solution.perturbation.qme_solution, + 𝒎ℂ = 𝓂.caches.qme_caches) if solved 𝓂.solution.perturbation.qme_solution = qme_sol end @@ -8358,11 +8363,12 @@ function get_relevant_steady_state_and_state_update(::Val{:first_order}, ∇₁ = calculate_jacobian(parameter_values, SS_and_pars, 𝓂) # , timer = timer)# |> Matrix - 𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁; - T = TT, - # timer = timer, - initial_guess = 𝓂.solution.perturbation.qme_solution, - opts = opts) + 𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁; + T = TT, + # timer = timer, + initial_guess = 𝓂.solution.perturbation.qme_solution, + opts = opts, + 𝒎ℂ = 𝓂.caches.qme_caches) if solved 𝓂.solution.perturbation.qme_solution = qme_sol end diff --git a/src/algorithms/quadratic_matrix_equation.jl b/src/algorithms/quadratic_matrix_equation.jl index a0c357dfc..d58ef340e 100644 --- a/src/algorithms/quadratic_matrix_equation.jl +++ b/src/algorithms/quadratic_matrix_equation.jl @@ -8,15 +8,16 @@ @stable default_mode = "disable" begin -function solve_quadratic_matrix_equation(A::AbstractMatrix{R}, - B::AbstractMatrix{R}, - C::AbstractMatrix{R}, - T::timings; +function solve_quadratic_matrix_equation(A::AbstractMatrix{R}, + B::AbstractMatrix{R}, + C::AbstractMatrix{R}, + T::timings; initial_guess::AbstractMatrix{R} = zeros(0,0), quadratic_matrix_equation_algorithm::Symbol = :schur, tol::AbstractFloat = 1e-14, acceptance_tol::AbstractFloat = 1e-8, - verbose::Bool = false) where R <: Real + verbose::Bool = false, + 𝒎ℂ::qme_caches = QME_caches()) where R <: Real if length(initial_guess) > 0 X = initial_guess @@ -38,13 +39,23 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R}, end end - sol, iterations, reached_tol = solve_quadratic_matrix_equation(A, B, C, - Val(quadratic_matrix_equation_algorithm), - T; + if quadratic_matrix_equation_algorithm == :doubling + sol, iterations, reached_tol = solve_quadratic_matrix_equation(A, B, C, + Val(:doubling), + T; + initial_guess = initial_guess, + tol = tol, + verbose = verbose, + 𝒎ℂ = 𝒎ℂ) + else + sol, iterations, reached_tol = solve_quadratic_matrix_equation(A, B, C, + Val(quadratic_matrix_equation_algorithm), + T; initial_guess = initial_guess, tol = tol, # timer = timer, verbose = verbose) + end if verbose println("Quadratic matrix equation solver: $quadratic_matrix_equation_algorithm - converged: $(reached_tol < acceptance_tol) in $iterations iterations to tolerance: $reached_tol") end @@ -60,13 +71,14 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R}, if verbose println("Quadratic matrix equation solver: schur - converged: $(reached_tol < acceptance_tol) in $iterations iterations to tolerance: $reached_tol") end else quadratic_matrix_equation_algorithm ≠ :doubling - sol, iterations, reached_tol = solve_quadratic_matrix_equation(A, B, C, - Val(:doubling), - T; + sol, iterations, reached_tol = solve_quadratic_matrix_equation(A, B, C, + Val(:doubling), + T; initial_guess = initial_guess, tol = tol, # timer = timer, - verbose = verbose) + verbose = verbose, + 𝒎ℂ = 𝒎ℂ) if verbose println("Quadratic matrix equation solver: doubling - converged: $(reached_tol < acceptance_tol) in $iterations iterations to tolerance: $reached_tol") end end @@ -206,16 +218,17 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R}, end -function solve_quadratic_matrix_equation(A::AbstractMatrix{R}, - B::AbstractMatrix{R}, - C::AbstractMatrix{R}, - ::Val{:doubling}, - T::timings; +function solve_quadratic_matrix_equation(A::AbstractMatrix{R}, + B::AbstractMatrix{R}, + C::AbstractMatrix{R}, + ::Val{:doubling}, + T::timings; initial_guess::AbstractMatrix{R} = zeros(0,0), tol::AbstractFloat = 1e-14, # timer::TimerOutput = TimerOutput(), verbose::Bool = false, - max_iter::Int = 100)::Tuple{Matrix{R}, Int64, R} where R <: AbstractFloat + max_iter::Int = 100, + 𝒎ℂ::qme_caches = QME_caches())::Tuple{Matrix{R}, Int64, R} where R <: AbstractFloat # Johannes Huber, Alexander Meyer-Gohde, Johanna Saecker (2024). Solving Linear DSGE Models with Structure Preserving Doubling Methods. # https://www.imfs-frankfurt.de/forschung/imfs-working-papers/details.html?tx_mmpublications_publicationsdetail%5Bcontroller%5D=Publication&tx_mmpublications_publicationsdetail%5Bpublication%5D=461&cHash=f53244e0345a27419a9d40a3af98c02f # https://arxiv.org/abs/2212.09491 @@ -282,19 +295,26 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R}, # end # timeit_debug # @timeit_debug timer "Invert EI" begin - fEI = ℒ.lu!(temp1, check = false) - - if !ℒ.issuccess(fEI) - return A, iter, 1.0 + alg = issparse(temp1) ? ð’Ū.UMFPACKFactorization() : ð’Ū.LUFactorization() + if !(typeof(𝒎ℂ.EI_cache.alg) === typeof(alg)) || + size(𝒎ℂ.EI_cache.A) != size(temp1) || + length(𝒎ℂ.EI_cache.b) != size(E,1) + prob = ð’Ū.LinearProblem(copy(temp1), zeros(R, size(temp1,1)), alg) + 𝒎ℂ.EI_cache = ð’Ū.init(prob, alg) + else + copy!(𝒎ℂ.EI_cache.A, temp1) + 𝒎ℂ.EI_cache.isfresh = true end - + for j in 1:size(E,2) + copyto!(𝒎ℂ.EI_cache.b, view(E,:,j)) + ð’Ū.solve!(𝒎ℂ.EI_cache) + temp3[:,j] .= 𝒎ℂ.EI_cache.u + end + ℒ.mul!(E_new, E, temp3) # end # timeit_debug # @timeit_debug timer "Compute E" begin - # Compute E = E * EI * E - ℒ.ldiv!(temp3, fEI, E) - ℒ.mul!(E_new, E, temp3) - # E_new = E / fEI * E + # E_new = E / EI * E # end # timeit_debug # @timeit_debug timer "Compute FI" begin @@ -309,26 +329,34 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R}, # end # timeit_debug # @timeit_debug timer "Invert FI" begin - fFI = ℒ.lu!(temp2, check = false) - - if !ℒ.issuccess(fFI) - return A, iter, 1.0 + fFI_cache_alg = issparse(temp2) ? ð’Ū.UMFPACKFactorization() : ð’Ū.LUFactorization() + if !(typeof(𝒎ℂ.FI_cache.alg) === typeof(fFI_cache_alg)) || + size(𝒎ℂ.FI_cache.A) != size(temp2) || + length(𝒎ℂ.FI_cache.b) != size(F,1) + prob = ð’Ū.LinearProblem(copy(temp2), zeros(R, size(temp2,1)), fFI_cache_alg) + 𝒎ℂ.FI_cache = ð’Ū.init(prob, fFI_cache_alg) + else + copy!(𝒎ℂ.FI_cache.A, temp2) + 𝒎ℂ.FI_cache.isfresh = true + end + for j in 1:size(F,2) + copyto!(𝒎ℂ.FI_cache.b, view(F,:,j)) + ð’Ū.solve!(𝒎ℂ.FI_cache) + temp3[:,j] .= 𝒎ℂ.FI_cache.u end - - # end # timeit_debug - # @timeit_debug timer "Compute F" begin - - # Compute F = F * FI * F - ℒ.ldiv!(temp3, fFI, F) ℒ.mul!(F_new, F, temp3) - # F_new = F / fFI * F + # F_new = F / FI * F # end # timeit_debug # @timeit_debug timer "Compute X_new" begin # Compute X_new = X + F * FI * X * E ℒ.mul!(temp3, X, E) - ℒ.ldiv!(fFI, temp3) + for j in 1:size(temp3,2) + copyto!(𝒎ℂ.FI_cache.b, view(temp3,:,j)) + ð’Ū.solve!(𝒎ℂ.FI_cache) + temp3[:,j] .= 𝒎ℂ.FI_cache.u + end ℒ.mul!(X_new, F, temp3) # X_new = F / fFI * X * E if i > 5 || guess_provided @@ -343,7 +371,11 @@ function solve_quadratic_matrix_equation(A::AbstractMatrix{R}, # Compute Y_new = Y + E * EI * Y * F ℒ.mul!(X, Y, F) # use X as temporary storage - ℒ.ldiv!(fEI, X) + for j in 1:size(X,2) + copyto!(𝒎ℂ.EI_cache.b, view(X,:,j)) + ð’Ū.solve!(𝒎ℂ.EI_cache) + X[:,j] .= 𝒎ℂ.EI_cache.u + end ℒ.mul!(Y_new, E, X) # Y_new = E / fEI * Y * F if i > 5 || guess_provided diff --git a/src/filter/inversion.jl b/src/filter/inversion.jl index 7427b8d4b..93fa5131a 100644 --- a/src/filter/inversion.jl +++ b/src/filter/inversion.jl @@ -3457,10 +3457,11 @@ function filter_data_with_model(𝓂::â„ģ, ∇₁ = calculate_jacobian(𝓂.parameter_values, SS_and_pars, 𝓂)# |> Matrix - 𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁; - T = T, - initial_guess = 𝓂.solution.perturbation.qme_solution, - opts = opts) + 𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁; + T = T, + initial_guess = 𝓂.solution.perturbation.qme_solution, + opts = opts, + 𝒎ℂ = 𝓂.caches.qme_caches) if solved 𝓂.solution.perturbation.qme_solution = qme_sol end diff --git a/src/filter/kalman.jl b/src/filter/kalman.jl index 767810b00..510ee03ee 100644 --- a/src/filter/kalman.jl +++ b/src/filter/kalman.jl @@ -601,7 +601,7 @@ function filter_and_smooth(𝓂::â„ģ, ∇₁ = calculate_jacobian(parameters, SS_and_pars, 𝓂)# |> Matrix - sol, qme_sol, solved = calculate_first_order_solution(∇₁; T = 𝓂.timings, opts = opts) + sol, qme_sol, solved = calculate_first_order_solution(∇₁; T = 𝓂.timings, opts = opts, 𝒎ℂ = 𝓂.caches.qme_caches) if solved 𝓂.solution.perturbation.qme_solution = qme_sol end diff --git a/src/get_functions.jl b/src/get_functions.jl index baf25f97f..194bd9853 100644 --- a/src/get_functions.jl +++ b/src/get_functions.jl @@ -1025,10 +1025,11 @@ function get_irf(𝓂::â„ģ, ∇₁ = calculate_jacobian(parameters, reference_steady_state, 𝓂)# |> Matrix - sol_mat, qme_sol, solved = calculate_first_order_solution(∇₁; - T = 𝓂.timings, + sol_mat, qme_sol, solved = calculate_first_order_solution(∇₁; + T = 𝓂.timings, opts = opts, - initial_guess = 𝓂.solution.perturbation.qme_solution) + initial_guess = 𝓂.solution.perturbation.qme_solution, + 𝒎ℂ = 𝓂.caches.qme_caches) if solved 𝓂.solution.perturbation.qme_solution = qme_sol @@ -1987,9 +1988,10 @@ function get_solution(𝓂::â„ģ, ∇₁ = calculate_jacobian(parameters, SS_and_pars, 𝓂)# |> Matrix - 𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁; T = 𝓂.timings, + 𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁; T = 𝓂.timings, opts = opts, - initial_guess = 𝓂.solution.perturbation.qme_solution) + initial_guess = 𝓂.solution.perturbation.qme_solution, + 𝒎ℂ = 𝓂.caches.qme_caches) if solved 𝓂.solution.perturbation.qme_solution = qme_sol end @@ -2172,10 +2174,11 @@ function get_conditional_variance_decomposition(𝓂::â„ģ; ∇₁ = calculate_jacobian(𝓂.parameter_values, SS_and_pars, 𝓂)# |> Matrix - 𝑚₁, qme_sol, solved = calculate_first_order_solution(∇₁; - T = 𝓂.timings, + 𝑚₁, qme_sol, solved = calculate_first_order_solution(∇₁; + T = 𝓂.timings, opts = opts, - initial_guess = 𝓂.solution.perturbation.qme_solution) + initial_guess = 𝓂.solution.perturbation.qme_solution, + 𝒎ℂ = 𝓂.caches.qme_caches) if solved 𝓂.solution.perturbation.qme_solution = qme_sol end @@ -2331,10 +2334,11 @@ function get_variance_decomposition(𝓂::â„ģ; ∇₁ = calculate_jacobian(𝓂.parameter_values, SS_and_pars, 𝓂)# |> Matrix - sol, qme_sol, solved = calculate_first_order_solution(∇₁; - T = 𝓂.timings, - opts = opts, - initial_guess = 𝓂.solution.perturbation.qme_solution) + sol, qme_sol, solved = calculate_first_order_solution(∇₁; + T = 𝓂.timings, + opts = opts, + initial_guess = 𝓂.solution.perturbation.qme_solution, + 𝒎ℂ = 𝓂.caches.qme_caches) if solved 𝓂.solution.perturbation.qme_solution = qme_sol end diff --git a/src/moments.jl b/src/moments.jl index 4d5982988..b473b2979 100644 --- a/src/moments.jl +++ b/src/moments.jl @@ -11,10 +11,11 @@ function calculate_covariance(parameters::Vector{R}, ∇₁ = calculate_jacobian(parameters, SS_and_pars, 𝓂) - sol, qme_sol, solved = calculate_first_order_solution(∇₁; - T = 𝓂.timings, - initial_guess = 𝓂.solution.perturbation.qme_solution, - opts = opts) + sol, qme_sol, solved = calculate_first_order_solution(∇₁; + T = 𝓂.timings, + initial_guess = 𝓂.solution.perturbation.qme_solution, + opts = opts, + 𝒎ℂ = 𝓂.caches.qme_caches) if solved 𝓂.solution.perturbation.qme_solution = qme_sol end @@ -56,10 +57,11 @@ function calculate_mean(parameters::Vector{T}, else ∇₁ = calculate_jacobian(parameters, SS_and_pars, 𝓂)# |> Matrix - 𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁; - T = 𝓂.timings, - initial_guess = 𝓂.solution.perturbation.qme_solution, - opts = opts) + 𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁; + T = 𝓂.timings, + initial_guess = 𝓂.solution.perturbation.qme_solution, + opts = opts, + 𝒎ℂ = 𝓂.caches.qme_caches) if !solved mean_of_variables = SS_and_pars[1:𝓂.timings.nVars] diff --git a/src/options_and_caches.jl b/src/options_and_caches.jl index 379f03558..c4899b3da 100644 --- a/src/options_and_caches.jl +++ b/src/options_and_caches.jl @@ -12,6 +12,21 @@ mutable struct sylvester_caches{G <: AbstractFloat} krylov_caches::krylov_caches{G} end +mutable struct qme_caches + EI_cache::ð’Ū.LinearCache + FI_cache::ð’Ū.LinearCache +end + +function QME_caches(;T::Type = Float64) + A = Matrix{T}(I, 1, 1) + b = zeros(T, 1) + alg = ð’Ū.LUFactorization() + prob = ð’Ū.LinearProblem(A, b, alg) + EI = ð’Ū.init(prob, alg) + FI = ð’Ū.init(prob, alg) + return qme_caches(EI, FI) +end + mutable struct higher_order_caches{F <: Real, G <: AbstractFloat} tmpkron0::SparseMatrixCSC{F, Int} tmpkron1::SparseMatrixCSC{F, Int} @@ -32,6 +47,7 @@ end mutable struct caches#{F <: Real, G <: AbstractFloat} second_order_caches::higher_order_caches#{F, G} third_order_caches::higher_order_caches#{F, G} + qme_caches::qme_caches end @@ -67,7 +83,8 @@ end function Caches(;T::Type = Float64, S::Type = Float64) caches( Higher_order_caches(T = T, S = S), - Higher_order_caches(T = T, S = S)) + Higher_order_caches(T = T, S = S), + QME_caches(T = S)) end diff --git a/src/perturbation.jl b/src/perturbation.jl index 5ed5c3592..7bb24232a 100644 --- a/src/perturbation.jl +++ b/src/perturbation.jl @@ -1,9 +1,10 @@ @stable default_mode = "disable" begin -function calculate_first_order_solution(∇₁::Matrix{R}; - T::timings, +function calculate_first_order_solution(∇₁::Matrix{R}; + T::timings, opts::CalculationOptions = merge_calculation_options(), - initial_guess::AbstractMatrix{R} = zeros(0,0))::Tuple{Matrix{R}, Matrix{R}, Bool} where R <: AbstractFloat + initial_guess::AbstractMatrix{R} = zeros(0,0), + 𝒎ℂ::qme_caches = QME_caches())::Tuple{Matrix{R}, Matrix{R}, Bool} where R <: AbstractFloat # @timeit_debug timer "Calculate 1st order solution" begin # @timeit_debug timer "Preprocessing" begin @@ -43,12 +44,13 @@ function calculate_first_order_solution(∇₁::Matrix{R}; # end # timeit_debug # @timeit_debug timer "Quadratic matrix equation solve" begin - sol, solved = solve_quadratic_matrix_equation(AĖƒâ‚Š, AĖƒâ‚€, AĖƒâ‚‹, T, + sol, solved = solve_quadratic_matrix_equation(AĖƒâ‚Š, AĖƒâ‚€, AĖƒâ‚‹, T, initial_guess = initial_guess, quadratic_matrix_equation_algorithm = opts.quadratic_matrix_equation_algorithm, tol = opts.tol.qme_tol, acceptance_tol = opts.tol.qme_acceptance_tol, - verbose = opts.verbose) + verbose = opts.verbose, + 𝒎ℂ = 𝒎ℂ) if !solved if opts.verbose println("Quadratic matrix equation solution failed.") end @@ -117,11 +119,12 @@ end end # dispatch_doctor -function rrule(::typeof(calculate_first_order_solution), +function rrule(::typeof(calculate_first_order_solution), ∇₁::Matrix{R}; - T::timings, + T::timings, opts::CalculationOptions = merge_calculation_options(), - initial_guess::AbstractMatrix{R} = zeros(0,0)) where R <: AbstractFloat + initial_guess::AbstractMatrix{R} = zeros(0,0), + 𝒎ℂ::qme_caches = QME_caches()) where R <: AbstractFloat # Forward pass to compute the output and intermediate values needed for the backward pass # @timeit_debug timer "Calculate 1st order solution" begin # @timeit_debug timer "Preprocessing" begin @@ -162,12 +165,13 @@ function rrule(::typeof(calculate_first_order_solution), # end # timeit_debug # @timeit_debug timer "Quadratic matrix equation solve" begin - sol, solved = solve_quadratic_matrix_equation(AĖƒâ‚Š, AĖƒâ‚€, AĖƒâ‚‹, T, + sol, solved = solve_quadratic_matrix_equation(AĖƒâ‚Š, AĖƒâ‚€, AĖƒâ‚‹, T, initial_guess = initial_guess, quadratic_matrix_equation_algorithm = opts.quadratic_matrix_equation_algorithm, tol = opts.tol.qme_tol, acceptance_tol = opts.tol.qme_acceptance_tol, - verbose = opts.verbose) + verbose = opts.verbose, + 𝒎ℂ = 𝒎ℂ) if !solved return (zeros(T.nVars,T.nPast_not_future_and_mixed + T.nExo), sol, false), x -> NoTangent(), NoTangent(), NoTangent() @@ -276,10 +280,11 @@ end @stable default_mode = "disable" begin -function calculate_first_order_solution(∇₁::Matrix{ℱ.Dual{Z,S,N}}; - T::timings, +function calculate_first_order_solution(∇₁::Matrix{ℱ.Dual{Z,S,N}}; + T::timings, opts::CalculationOptions = merge_calculation_options(), - initial_guess::AbstractMatrix{<:AbstractFloat} = zeros(0,0))::Tuple{Matrix{ℱ.Dual{Z,S,N}}, Matrix{Float64}, Bool} where {Z,S,N} + initial_guess::AbstractMatrix{<:AbstractFloat} = zeros(0,0), + 𝒎ℂ::qme_caches = QME_caches())::Tuple{Matrix{ℱ.Dual{Z,S,N}}, Matrix{Float64}, Bool} where {Z,S,N} âˆ‡Ė‚â‚ = ℱ.value.(∇₁) expand = [ℒ.I(T.nVars)[T.future_not_past_and_mixed_idx,:], ℒ.I(T.nVars)[T.past_not_future_and_mixed_idx,:]] @@ -287,7 +292,7 @@ function calculate_first_order_solution(∇₁::Matrix{ℱ.Dual{Z,S,N}}; A = âˆ‡Ė‚â‚[:,1:T.nFuture_not_past_and_mixed] * expand[1] B = âˆ‡Ė‚â‚[:,T.nFuture_not_past_and_mixed .+ range(1,T.nVars)] - 𝐒₁, qme_sol, solved = calculate_first_order_solution(âˆ‡Ė‚â‚; T = T, opts = opts, initial_guess = initial_guess) + 𝐒₁, qme_sol, solved = calculate_first_order_solution(âˆ‡Ė‚â‚; T = T, opts = opts, initial_guess = initial_guess, 𝒎ℂ = 𝒎ℂ) if !solved return ∇₁, qme_sol, false