Skip to content

Commit 184f792

Browse files
committed
reuse initial full proposal step in laplace
1 parent c99fbcf commit 184f792

File tree

4 files changed

+251
-59
lines changed

4 files changed

+251
-59
lines changed

stan/math/mix/functor/laplace_marginal_density_estimator.hpp

Lines changed: 54 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,9 @@ struct NewtonState {
342342
/** @brief Status of the most recent Wolfe line search */
343343
WolfeStatus wolfe_status;
344344

345+
/** @brief Cached proposal evaluated before the Wolfe line search. */
346+
WolfeData proposal;
347+
345348
/** @brief Workspace vector: b = W * theta + grad(log_lik) */
346349
Eigen::VectorXd b;
347350

@@ -377,6 +380,7 @@ struct NewtonState {
377380
: wolfe_info(std::forward<ObjFun>(obj_fun), covariance.llt().solve(theta_init),
378381
std::forward<ThetaInitializer>(theta_init),
379382
std::forward<ThetaGradFun>(theta_grad_f)),
383+
proposal(theta_size),
380384
b(theta_size),
381385
B(theta_size, theta_size),
382386
prev_g(theta_size) {
@@ -407,9 +411,12 @@ struct NewtonState {
407411
*/
408412
const auto& prev() const& { return wolfe_info.prev_; }
409413
auto&& prev() && { return std::move(wolfe_info).prev(); }
414+
auto& proposal_step() & { return proposal; }
415+
const auto& proposal_step() const& { return proposal; }
416+
auto&& proposal_step() && { return std::move(proposal); }
410417
template <typename Options>
411418
inline void update_next_step(const Options& options) {
412-
this->prev().update(this->curr());
419+
this->prev().swap(this->curr());
413420
this->curr().alpha()
414421
= std::clamp(this->curr().alpha(), 0.0, options.line_search.max_alpha);
415422
}
@@ -485,7 +492,8 @@ struct CholeskyWSolverDiag {
485492
* @tparam LLFun Type of the log-likelihood functor
486493
* @tparam LLTupleArgs Type of the likelihood arguments tuple
487494
* @tparam CovarMat Type of the covariance matrix
488-
* @param[in,out] state Shared Newton state (modified: B, b, curr().a())
495+
* @param[in,out] state Shared Newton state (modified: B, b,
496+
* proposal_step().a())
489497
* @param[in] ll_fun Log-likelihood functor
490498
* @param[in,out] ll_args Additional arguments for the likelihood
491499
* @param[in] covariance Prior covariance matrix Sigma
@@ -521,12 +529,12 @@ struct CholeskyWSolverDiag {
521529

522530
// 3. Factorize B with jittering fallback
523531
llt_with_jitter(llt_B, state.B);
524-
// 4. Solve for curr.a
532+
// 4. Solve for the raw Newton proposal in a-space.
525533
state.b.noalias() = (W_diag.array() * state.prev().theta().array()).matrix()
526534
+ state.prev().theta_grad();
527535
auto L = llt_B.matrixL();
528536
auto LT = llt_B.matrixU();
529-
state.curr().a().noalias()
537+
state.proposal_step().a().noalias()
530538
= state.b
531539
- W_r_diag.asDiagonal()
532540
* LT.solve(
@@ -615,7 +623,8 @@ struct CholeskyWSolverBlock {
615623
* @tparam LLFun Type of the log-likelihood functor
616624
* @tparam LLTupleArgs Type of the likelihood arguments tuple
617625
* @tparam CovarMat Type of the covariance matrix
618-
* @param[in,out] state Shared Newton state (modified: B, b, curr().a())
626+
* @param[in,out] state Shared Newton state (modified: B, b,
627+
* proposal_step().a())
619628
* @param[in] ll_fun Log-likelihood functor
620629
* @param[in,out] ll_args Additional arguments for the likelihood
621630
* @param[in] covariance Prior covariance matrix Sigma
@@ -653,12 +662,12 @@ struct CholeskyWSolverBlock {
653662
// 4. Factorize B with jittering fallback
654663
llt_with_jitter(llt_B, state.B);
655664

656-
// 5. Solve for curr.a
665+
// 5. Solve for the raw Newton proposal in a-space.
657666
state.b.noalias()
658667
= W_block * state.prev().theta() + state.prev().theta_grad();
659668
auto L = llt_B.matrixL();
660669
auto LT = llt_B.matrixU();
661-
state.curr().a().noalias()
670+
state.proposal_step().a().noalias()
662671
= state.b - W_r * LT.solve(L.solve(W_r * (covariance * state.b)));
663672
}
664673

@@ -736,7 +745,7 @@ struct CholeskyKSolver {
736745
* @tparam LLFun Type of the log-likelihood functor
737746
* @tparam LLTupleArgs Type of the likelihood arguments tuple
738747
* @tparam CovarMat Type of the covariance matrix
739-
* @param[in] state Shared Newton state (modified: B, b, curr().a())
748+
* @param[in] state Shared Newton state (modified: B, b, proposal_step().a())
740749
* @param[in] ll_fun Log-likelihood functor
741750
* @param[in] ll_args Additional arguments for the likelihood
742751
* @param[in] covariance Prior covariance matrix Sigma
@@ -763,12 +772,12 @@ struct CholeskyKSolver {
763772
// 3. Factorize B with jittering fallback
764773
llt_with_jitter(llt_B, state.B);
765774

766-
// 4. Solve for curr.a
775+
// 4. Solve for the raw Newton proposal in a-space.
767776
state.b.noalias()
768777
= W_full * state.prev().theta() + state.prev().theta_grad();
769778
auto L = llt_B.matrixL();
770779
auto LT = llt_B.matrixU();
771-
state.curr().a().noalias()
780+
state.proposal_step().a().noalias()
772781
= K_root.transpose().template triangularView<Eigen::Upper>().solve(
773782
LT.solve(L.solve(K_root.transpose() * state.b)));
774783
}
@@ -833,7 +842,7 @@ struct LUSolver {
833842
* @tparam LLFun Type of the log-likelihood functor
834843
* @tparam LLTupleArgs Type of the likelihood arguments tuple
835844
* @tparam CovarMat Type of the covariance matrix
836-
* @param[in,out] state Shared Newton state (modified: b, curr().a())
845+
* @param[in,out] state Shared Newton state (modified: b, proposal_step().a())
837846
* @param[in] ll_fun Log-likelihood functor
838847
* @param[in,out] ll_args Additional arguments for the likelihood
839848
* @param[in] covariance Prior covariance matrix Sigma
@@ -855,10 +864,10 @@ struct LUSolver {
855864
lu.compute(Eigen::MatrixXd::Identity(theta_size, theta_size)
856865
+ covariance * W_full);
857866

858-
// 3. Solve for curr.a
867+
// 3. Solve for the raw Newton proposal in a-space.
859868
state.b.noalias()
860869
= W_full * state.prev().theta() + state.prev().theta_grad();
861-
state.curr().a().noalias()
870+
state.proposal_step().a().noalias()
862871
= state.b - W_full * lu.solve(covariance * state.b);
863872
}
864873

@@ -932,29 +941,32 @@ inline auto run_newton_loop(SolverPolicy& solver, NewtonStateT& state,
932941
solver.solve_step(state, ll_fun, ll_args, covariance,
933942
options.hessian_block_size, msgs);
934943
if (!state.final_loop) {
935-
state.wolfe_info.p_ = state.curr().a() - state.prev().a();
944+
auto&& proposal = state.proposal_step();
945+
state.wolfe_info.p_ = proposal.a() - state.prev().a();
936946
state.prev_g.noalias() = -covariance * state.prev().a()
937947
+ covariance * state.prev().theta_grad();
938948
state.wolfe_info.init_dir_ = state.prev_g.dot(state.wolfe_info.p_);
939949
// Flip direction if not ascending
940950
state.wolfe_info.flip_direction();
941951
auto&& scratch = state.wolfe_info.scratch_;
942-
scratch.alpha() = 1.0;
943-
update_fun(scratch, state.curr(), state.prev(), scratch.eval_,
944-
state.wolfe_info.p_);
945-
// Save the full Newton step objective before the Wolfe line search
946-
// overwrites scratch with intermediate trial points.
947-
const double full_newton_obj = scratch.eval_.obj();
948-
if (scratch.alpha() <= options.line_search.min_alpha) {
949-
state.wolfe_status.accept_ = false;
950-
finish_update = true;
952+
proposal.eval_.alpha() = 1.0;
953+
const bool proposal_valid = update_fun(
954+
proposal, state.curr(), state.prev(), proposal.eval_,
955+
state.wolfe_info.p_);
956+
const bool cached_proposal_ok
957+
= proposal_valid && std::isfinite(proposal.obj())
958+
&& std::isfinite(proposal.dir())
959+
&& proposal.alpha() > options.line_search.min_alpha;
960+
if (!cached_proposal_ok) {
961+
state.wolfe_status
962+
= WolfeStatus{WolfeReturn::StepTooSmall, 1, 0, false};
951963
} else if (options.line_search.max_iterations == 0) {
952-
state.curr().update(scratch);
953-
state.wolfe_status.accept_ = true;
964+
state.curr().update(proposal);
965+
state.wolfe_status = WolfeStatus{WolfeReturn::Continue, 1, 0, true};
954966
} else {
955-
Eigen::VectorXd s = scratch.a() - state.prev().a();
967+
Eigen::VectorXd s = proposal.a() - state.prev().a();
956968
auto full_step_grad
957-
= (-covariance * scratch.a() + covariance * scratch.theta_grad())
969+
= (-covariance * proposal.a() + covariance * proposal.theta_grad())
958970
.eval();
959971
state.curr().alpha() = barzilai_borwein_step_size(
960972
s, full_step_grad, state.prev_g, state.prev().alpha(),
@@ -963,47 +975,30 @@ inline auto run_newton_loop(SolverPolicy& solver, NewtonStateT& state,
963975
state.wolfe_status = internal::wolfe_line_search(
964976
state.wolfe_info, update_fun, options.line_search, msgs);
965977
}
966-
// When the Wolfe line search rejects, don't immediately terminate.
967-
// Instead, let the Newton loop try at least one more iteration.
968-
// The original code compared the stale curr.obj() (which equalled
969-
// prev.obj() after the swap in update_next_step) and would always
970-
// terminate on ANY Wolfe rejection — even on the very first Newton
971-
// step. Now we only declare search_failed if the full Newton step
972-
// itself didn't improve the objective.
973-
bool search_failed;
974-
if (!state.wolfe_status.accept_) {
975-
if (full_newton_obj > state.prev().obj()) {
976-
// The full Newton step (evaluated before Wolfe ran) improved
977-
// the objective. Re-evaluate scratch at the full Newton step
978-
// so we can accept it as the current iterate.
979-
scratch.eval_.alpha() = 1.0;
980-
update_fun(scratch, state.curr(), state.prev(), scratch.eval_,
981-
state.wolfe_info.p_);
982-
state.curr().update(scratch);
983-
state.wolfe_status.accept_ = true;
984-
search_failed = false;
985-
} else {
986-
search_failed = true;
987-
}
988-
} else {
978+
bool search_failed = !state.wolfe_status.accept_;
979+
const bool proposal_armijo_ok
980+
= cached_proposal_ok
981+
&& internal::check_armijo(
982+
proposal.obj(), state.prev().obj(), proposal.alpha(),
983+
state.wolfe_info.init_dir_, options.line_search);
984+
if (search_failed && proposal_armijo_ok) {
985+
state.curr().update(proposal);
986+
state.wolfe_status = WolfeStatus{WolfeReturn::Armijo,
987+
state.wolfe_status.num_evals_,
988+
state.wolfe_status.num_backtracks_,
989+
true};
989990
search_failed = false;
990991
}
991-
/**
992-
* Stop when objective change is small (absolute AND relative), or when
993-
* a rejected Wolfe step fails to improve; finish_update then exits the
994-
* Newton loop.
995-
*/
996-
double obj_change = std::abs(state.curr().obj() - state.prev().obj());
997992
bool objective_converged
998-
= obj_change < options.tolerance
999-
&& obj_change < options.tolerance * std::abs(state.prev().obj());
993+
= state.wolfe_status.accept_
994+
&& std::abs(state.curr().obj() - state.prev().obj())
995+
< options.tolerance;
1000996
finish_update = objective_converged || search_failed;
1001997
}
1002998
if (finish_update) {
1003999
if (!state.final_loop && state.wolfe_status.accept_) {
10041000
// Do one final loop with exact wolfe conditions
10051001
state.final_loop = true;
1006-
// NOTE: Swapping here so we need to swap prev and curr later
10071002
state.update_next_step(options);
10081003
continue;
10091004
}

stan/math/mix/functor/wolfe_line_search.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,12 @@ struct WolfeData {
461461
a_.swap(other.a_);
462462
eval_ = other.eval_;
463463
}
464+
/** @brief Exchange all state with @p other in O(1), losing nothing.
 *
 * Swaps the theta_, theta_grad_, and a_ buffers and also exchanges the
 * cached evaluations via std::swap. Unlike the update() overloads below,
 * which copy/overwrite eval_ from @p other, this is a true two-way swap:
 * both objects keep a valid evaluation afterwards.
 */
void swap(WolfeData& other) {
  theta_.swap(other.theta_);
  theta_grad_.swap(other.theta_grad_);
  a_.swap(other.a_);
  std::swap(eval_, other.eval_);
}
464470
void update(WolfeData& other, const Eval& eval) {
465471
theta_.swap(other.theta_);
466472
a_.swap(other.a_);
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
#include <gtest/gtest.h>
2+
#include <stan/math.hpp>
3+
#include <stan/math/mix.hpp>
4+
5+
#include <cmath>
6+
#include <sstream>
7+
#include <tuple>
8+
9+
namespace stan::math {
10+
namespace {
11+
12+
struct IdentityCovariance {
13+
template <typename Stream>
14+
Eigen::MatrixXd operator()(Stream* /*msgs*/) const {
15+
return Eigen::MatrixXd::Identity(1, 1);
16+
}
17+
};
18+
19+
struct QuarticLikelihood {
20+
template <typename Theta>
21+
auto operator()(const Theta& theta, std::ostream* /*msgs*/) const {
22+
const auto& x = theta(0);
23+
const auto x_sq = stan::math::square(x);
24+
return 2.0 * x - 0.5 * x_sq - 0.5 * stan::math::square(x_sq);
25+
}
26+
};
27+
28+
struct TinyQuarticLikelihood {
29+
template <typename Theta>
30+
auto operator()(const Theta& theta, std::ostream* /*msgs*/) const {
31+
return 1e-8 * QuarticLikelihood{}(theta, nullptr);
32+
}
33+
};
34+
35+
struct StubNewtonSolver {
36+
double proposal_a;
37+
38+
template <typename NewtonStateT, typename LLFun, typename LLTupleArgs,
39+
typename CovarMat>
40+
void solve_step(NewtonStateT& state, const LLFun& /*ll_fun*/,
41+
const LLTupleArgs& /*ll_args*/,
42+
const CovarMat& /*covariance*/, int /*hessian_block_size*/,
43+
std::ostream* /*msgs*/) const {
44+
state.proposal_step().a()(0) = proposal_a;
45+
}
46+
47+
double compute_log_determinant() const { return 0.0; }
48+
49+
template <typename NewtonStateT>
50+
double build_result(NewtonStateT& state, double /*log_det*/) const {
51+
return state.prev().a()(0);
52+
}
53+
};
54+
55+
template <typename Likelihood>
56+
double run_laplace(const Likelihood& ll_fun, double theta0_value,
57+
double tolerance, int max_num_steps,
58+
int max_steps_line_search, std::ostream* msgs) {
59+
Eigen::VectorXd theta0(1);
60+
theta0 << theta0_value;
61+
return stan::math::laplace_marginal_tol<false>(
62+
ll_fun, std::tuple<>{}, 1, IdentityCovariance{}, std::tuple<>{},
63+
std::make_tuple(theta0, tolerance, max_num_steps, 1,
64+
max_steps_line_search, true),
65+
msgs);
66+
}
67+
68+
// The no-line-search path (max_steps_line_search == 0) and the Wolfe
// line-search path must converge to the same marginal density.
TEST(LaplaceMarginalDensityEstimator, PublicLineSearchMatchesDirectStep) {
  std::ostringstream direct_log;
  std::ostringstream wolfe_log;

  const double direct
      = run_laplace(QuarticLikelihood{}, 2.0, 1e-12, 50, 0, &direct_log);
  const double searched
      = run_laplace(QuarticLikelihood{}, 2.0, 1e-12, 50, 1000, &wolfe_log);

  EXPECT_TRUE(std::isfinite(direct));
  EXPECT_TRUE(std::isfinite(searched));
  EXPECT_NEAR(direct, searched, 1e-8);
}
81+
82+
// With a likelihood scaled to ~1e-8, objective changes are tiny, so the
// absolute tolerance should stop the loop before the 6-step cap — the
// diagnostics stream must not mention hitting the iteration limit.
TEST(LaplaceMarginalDensityEstimator, AbsoluteObjectiveToleranceStopsNearZero) {
  std::ostringstream solver_log;

  const double density
      = run_laplace(TinyQuarticLikelihood{}, 0.0, 1e-8, 6, 1000, &solver_log);

  EXPECT_TRUE(std::isfinite(density));
  EXPECT_EQ(solver_log.str().find("max number of iterations"),
            std::string::npos);
}
91+
92+
// Regression test: when the update functor rejects the cached proposal
// (returns false) and reports alpha below min_alpha, the Newton loop must
// record WolfeReturn::StepTooSmall and must NOT resurrect the proposal via
// the Armijo fallback branch.
TEST(LaplaceMarginalDensityEstimator,
     InvalidCachedProposalDoesNotTriggerArmijoFallback) {
  Eigen::MatrixXd covariance = Eigen::MatrixXd::Identity(1, 1);
  Eigen::VectorXd theta0 = Eigen::VectorXd::Zero(1);
  // Constant objective and zero gradient: no step can genuinely improve,
  // so any accepted step here would indicate a bug in the loop.
  auto obj_fun = [](const auto& /*a*/, const auto& /*theta*/) {
    return -1.0;
  };
  auto theta_grad_f = [](const auto& theta) {
    return Eigen::VectorXd::Zero(theta.size());
  };
  internal::NewtonState state(1, obj_fun, theta_grad_f, covariance, theta0);
  laplace_options_base options;
  options.hessian_block_size = 1;
  options.max_num_steps = 1;
  options.tolerance = 1e-12;
  options.line_search.max_iterations = 5;
  options.line_search.min_alpha = 1e-8;

  // The stub always proposes a = 5.0, far from the zero start point.
  StubNewtonSolver solver{5.0};
  Eigen::Index step_iter = 1;
  // Simulates an invalid proposal evaluation: returns false and sets alpha
  // to half of min_alpha, so the loop's cached-proposal validity check
  // must fail.
  auto failing_update = [min_alpha = options.line_search.min_alpha](
                            auto& /*proposal*/, auto&& /*curr*/,
                            auto&& /*prev*/, auto& eval_in, auto&& /*p*/) {
    eval_in.alpha() = 0.5 * min_alpha;
    return false;
  };
  auto unused_ll = [](const auto& /*theta*/, std::ostream* /*msgs*/) {
    return 0.0;
  };

  const double result
      = internal::run_newton_loop(solver, state, options, step_iter, unused_ll,
                                  std::tuple<>{}, covariance, failing_update,
                                  nullptr);

  // StubNewtonSolver::build_result returns prev().a()(0), which must still
  // be the zero starting value since no step was accepted.
  EXPECT_DOUBLE_EQ(result, 0.0);
  EXPECT_FALSE(state.wolfe_status.accept_);
  EXPECT_EQ(state.wolfe_status.stop_, internal::WolfeReturn::StepTooSmall);
}
131+
132+
} // namespace
133+
} // namespace stan::math

0 commit comments

Comments
 (0)