diff --git a/core/stop/iteration.cpp b/core/stop/iteration.cpp index e7c39f4c77a..da676df9756 100644 --- a/core/stop/iteration.cpp +++ b/core/stop/iteration.cpp @@ -4,7 +4,7 @@ #include "ginkgo/core/stop/iteration.hpp" -#include "ginkgo/core/base/abstract_factory.hpp" +#include "core/stop/iteration.hpp" namespace gko { @@ -30,5 +30,14 @@ deferred_factory_parameter max_iters(size_type count) } +deferred_factory_parameter min_iters( + size_type count, deferred_factory_parameter criterion) +{ + return MinIterationWrapper::build() + .with_min_iters(count) + .with_inner_criterion(criterion); +} + + } // namespace stop } // namespace gko diff --git a/core/stop/iteration.hpp b/core/stop/iteration.hpp new file mode 100644 index 00000000000..e57cad21b62 --- /dev/null +++ b/core/stop/iteration.hpp @@ -0,0 +1,76 @@ +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_CORE_STOP_ITERATION_HPP_ +#define GKO_CORE_STOP_ITERATION_HPP_ + +#include +#include +#include +#include + +namespace gko { +namespace stop { + + +class MinIterationWrapper + : public EnablePolymorphicObject { + friend class EnablePolymorphicObject; + +public: + GKO_CREATE_FACTORY_PARAMETERS(parameters, Factory) + { + /** + * Minimum number of iterations, after which we check the inner + * criterion + */ + size_type min_iters{0}; + + parameters_type& with_min_iters(size_type value) + { + this->min_iters = value; + return *this; + } + + std::shared_ptr GKO_DEFERRED_FACTORY_PARAMETER( + inner_criterion); + }; + GKO_ENABLE_CRITERION_FACTORY(MinIterationWrapper, parameters, Factory); + GKO_ENABLE_BUILD_METHOD(Factory); + +protected: + bool check_impl(uint8 stopping_id, bool set_finalized, + array* stop_status, bool* one_changed, + const Updater& updater) override + { + if (updater.num_iterations_ < this->get_parameters().min_iters) { + return false; + } + return inner_criterion_->check(stopping_id, set_finalized, stop_status, + one_changed, updater); + } + + explicit MinIterationWrapper(std::shared_ptr exec) + : EnablePolymorphicObject( + std::move(exec)) + {} + + explicit MinIterationWrapper(const Factory* factory, + const CriterionArgs& args) + : EnablePolymorphicObject( + factory->get_executor()), + parameters_{factory->get_parameters()}, + inner_criterion_{ + factory->get_parameters().inner_criterion->generate(args)} + {} + + std::shared_ptr inner_criterion_; +}; + + +} // namespace stop +} // namespace gko + + +#endif // GKO_CORE_STOP_ITERATION_HPP_ diff --git a/core/test/stop/iteration.cpp b/core/test/stop/iteration.cpp index e538885e5d6..37fd2f58b22 100644 --- a/core/test/stop/iteration.cpp +++ b/core/test/stop/iteration.cpp @@ -1,7 +1,9 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause +#include "core/stop/iteration.hpp" + #include #include @@ -41,4 +43,24 @@ TEST_F(Iteration, CanCreateCriterion) } +TEST_F(Iteration, CanCreateMinIterationWithInnerCriterion) +{ + auto factory = gko::as( + gko::stop::min_iters(10, gko::stop::max_iters(100), + gko::stop::max_iters(1000)) + .on(exec_)); + + auto inner = gko::as( + factory->get_parameters().inner_criterion); + ASSERT_EQ(factory->get_parameters().min_iters, 10); + ASSERT_EQ(inner->get_parameters().criteria.size(), 2); + auto inner1 = gko::as( + inner->get_parameters().criteria.at(0)); + auto inner2 = gko::as( + inner->get_parameters().criteria.at(1)); + ASSERT_EQ(inner1->get_parameters().max_iters, 100); + ASSERT_EQ(inner2->get_parameters().max_iters, 1000); +} + + } // namespace diff --git a/include/ginkgo/core/stop/iteration.hpp b/include/ginkgo/core/stop/iteration.hpp index 5d9e848b6c7..04cb73710cb 100644 --- a/include/ginkgo/core/stop/iteration.hpp +++ b/include/ginkgo/core/stop/iteration.hpp @@ -6,10 +6,10 @@ #define GKO_PUBLIC_CORE_STOP_ITERATION_HPP_ +#include +#include #include -#include "ginkgo/core/base/abstract_factory.hpp" - namespace gko { namespace stop { @@ -83,6 +83,51 @@ class Iteration : public EnablePolymorphicObject { deferred_factory_parameter max_iters(size_type count); +/** + * Creates the precursor to an MinimumIteration stopping criterion factory, to + * be used in conjunction with `.with_criteria(...)` function calls when + * building a solver factory. This stopping criterion wraps another stopping + * criterion inside, which only starts getting checked after the first `count` + * iterations finished. + * + * Full usage example: Stop when the relative residual + * norm is below $10^{-10}$, but with at least 100 iterations. + * ```cpp + * auto factory = gko::solver::Cg::build() + * .with_criteria(gko::stop::min_iters(100, + * gko::stop::relative_residual_norm(1e-10))) + * .on(exec); + * ``` + * + * @param count the number of iterations after which to start checking the + * inner criterion + * @param criterion the inner criterion, which will not be checked until + * `count` iterations finished, afterwards the min_iters + * stopping criterion behaves like the inner criterion. + * @return a deferred_factory_parameter that can be passed to the + * `with_criteria` function when building a solver. + */ +deferred_factory_parameter min_iters( + size_type count, deferred_factory_parameter criterion); + + +/** + * @copydoc min_iters(size_type, deferred_factory_parameter) + * This version supports supplying multiple stopping criteria independently, all + * of which will only be checked after the minimum iteration count has been + * exceeded. + */ +template +std::enable_if_t= 2, + deferred_factory_parameter> +min_iters(size_type count, Args&&... criteria) +{ + std::vector> criterion_vec{ + std::forward(criteria)...}; + return min_iters(count, Combined::build().with_criteria(criterion_vec)); +}; + + } // namespace stop } // namespace gko diff --git a/reference/test/stop/residual_norm_kernels.cpp b/reference/test/stop/residual_norm_kernels.cpp index 024b83377aa..100e9955e3d 100644 --- a/reference/test/stop/residual_norm_kernels.cpp +++ b/reference/test/stop/residual_norm_kernels.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include "core/test/utils.hpp" @@ -510,6 +511,65 @@ TYPED_TEST(ResidualNorm, WaitsTillResidualGoalMultipleRHS) } +TYPED_TEST(ResidualNorm, WorksWithMinIterationCount) +{ + using Mtx = typename TestFixture::Mtx; + using NormVector = typename TestFixture::NormVector; + using T_nc = gko::remove_complex; + auto initial_res = gko::initialize({100.0}, this->exec_); + std::shared_ptr rhs = gko::initialize({10.0}, this->exec_); + auto min_factory = gko::stop::min_iters( + 10, gko::stop::ResidualNorm::build() + .with_baseline(gko::stop::mode::absolute) + .with_reduction_factor(r::value)) + .on(this->exec_); + auto min_criterion = + min_factory->generate(nullptr, rhs, nullptr, initial_res.get()); + { + auto res_norm = gko::initialize({100.0}, this->exec_); + constexpr gko::uint8 RelativeStoppingId{1}; + bool one_changed{}; + gko::array stop_status(this->exec_, 1); + stop_status.get_data()[0].reset(); + + ASSERT_FALSE( + min_criterion->update() + .num_iterations(9) + .residual_norm(res_norm) + .check(RelativeStoppingId, true, &stop_status, &one_changed)); + ASSERT_EQ(stop_status.get_data()[0].has_converged(), false); + ASSERT_EQ(one_changed, false); + + res_norm->at(0) = r::value * 0.9; + ASSERT_FALSE( + min_criterion->update() + .num_iterations(9) + .residual_norm(res_norm) + .check(RelativeStoppingId, true, &stop_status, &one_changed)); + ASSERT_EQ(stop_status.get_data()[0].has_converged(), false); + ASSERT_EQ(one_changed, false); + + res_norm->at(0) = r::value * 1.1; + ASSERT_FALSE( + min_criterion->update() + .num_iterations(10) + .residual_norm(res_norm) + .check(RelativeStoppingId, true, &stop_status, &one_changed)); + ASSERT_EQ(stop_status.get_data()[0].has_converged(), false); + ASSERT_EQ(one_changed, false); + + res_norm->at(0) = r::value * 0.9; + ASSERT_TRUE( + min_criterion->update() + .num_iterations(10) + .residual_norm(res_norm) + .check(RelativeStoppingId, true, &stop_status, &one_changed)); + ASSERT_EQ(stop_status.get_data()[0].has_converged(), true); + ASSERT_EQ(one_changed, true); + } +} + + TYPED_TEST(ResidualNorm, SimplifiedInterface) { using Mtx = typename TestFixture::Mtx;