Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion core/stop/iteration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#include "ginkgo/core/stop/iteration.hpp"

#include "ginkgo/core/base/abstract_factory.hpp"
#include "core/stop/iteration.hpp"


namespace gko {
Expand All @@ -30,5 +30,14 @@ deferred_factory_parameter<Iteration::Factory> max_iters(size_type count)
}


deferred_factory_parameter<CriterionFactory> min_iters(
size_type count, deferred_factory_parameter<CriterionFactory> criterion)
{
return MinIterationWrapper::build()
.with_min_iters(count)
.with_inner_criterion(criterion);
}


} // namespace stop
} // namespace gko
76 changes: 76 additions & 0 deletions core/stop/iteration.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_CORE_STOP_ITERATION_HPP_
#define GKO_CORE_STOP_ITERATION_HPP_

#include <ginkgo/core/base/abstract_factory.hpp>
#include <ginkgo/core/stop/combined.hpp>
#include <ginkgo/core/stop/criterion.hpp>
#include <ginkgo/core/stop/iteration.hpp>

namespace gko {
namespace stop {


class MinIterationWrapper
: public EnablePolymorphicObject<MinIterationWrapper, Criterion> {
friend class EnablePolymorphicObject<MinIterationWrapper, Criterion>;

public:
GKO_CREATE_FACTORY_PARAMETERS(parameters, Factory)
{
/**
* Minimum number of iterations, after which we check the inner
* criterion
*/
size_type min_iters{0};

parameters_type& with_min_iters(size_type value)
{
this->min_iters = value;
return *this;
}

std::shared_ptr<const CriterionFactory> GKO_DEFERRED_FACTORY_PARAMETER(
inner_criterion);
};
GKO_ENABLE_CRITERION_FACTORY(MinIterationWrapper, parameters, Factory);
GKO_ENABLE_BUILD_METHOD(Factory);

protected:
bool check_impl(uint8 stopping_id, bool set_finalized,
array<stopping_status>* stop_status, bool* one_changed,
const Updater& updater) override
{
if (updater.num_iterations_ < this->get_parameters().min_iters) {
return false;
}
return inner_criterion_->check(stopping_id, set_finalized, stop_status,
one_changed, updater);
}

explicit MinIterationWrapper(std::shared_ptr<const gko::Executor> exec)
: EnablePolymorphicObject<MinIterationWrapper, Criterion>(
std::move(exec))
{}

explicit MinIterationWrapper(const Factory* factory,
const CriterionArgs& args)
: EnablePolymorphicObject<MinIterationWrapper, Criterion>(
factory->get_executor()),
parameters_{factory->get_parameters()},
inner_criterion_{
factory->get_parameters().inner_criterion->generate(args)}
{}

std::shared_ptr<Criterion> inner_criterion_;
};


} // namespace stop
} // namespace gko


#endif // GKO_CORE_STOP_ITERATION_HPP_
24 changes: 23 additions & 1 deletion core/test/stop/iteration.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include "core/stop/iteration.hpp"

#include <gtest/gtest.h>

#include <ginkgo/core/stop/iteration.hpp>
Expand Down Expand Up @@ -41,4 +43,24 @@ TEST_F(Iteration, CanCreateCriterion)
}


TEST_F(Iteration, CanCreateMinIterationWithInnerCriterion)
{
auto factory = gko::as<gko::stop::MinIterationWrapper::Factory>(
gko::stop::min_iters(10, gko::stop::max_iters(100),
gko::stop::max_iters(1000))
.on(exec_));

auto inner = gko::as<gko::stop::Combined::Factory>(
factory->get_parameters().inner_criterion);
ASSERT_EQ(factory->get_parameters().min_iters, 10);
ASSERT_EQ(inner->get_parameters().criteria.size(), 2);
auto inner1 = gko::as<gko::stop::Iteration::Factory>(
inner->get_parameters().criteria.at(0));
auto inner2 = gko::as<gko::stop::Iteration::Factory>(
inner->get_parameters().criteria.at(1));
ASSERT_EQ(inner1->get_parameters().max_iters, 100);
ASSERT_EQ(inner2->get_parameters().max_iters, 1000);
}


} // namespace
49 changes: 47 additions & 2 deletions include/ginkgo/core/stop/iteration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
#define GKO_PUBLIC_CORE_STOP_ITERATION_HPP_


#include <ginkgo/core/base/abstract_factory.hpp>
#include <ginkgo/core/stop/combined.hpp>
#include <ginkgo/core/stop/criterion.hpp>

#include "ginkgo/core/base/abstract_factory.hpp"


namespace gko {
namespace stop {
Expand Down Expand Up @@ -83,6 +83,51 @@ class Iteration : public EnablePolymorphicObject<Iteration, Criterion> {
deferred_factory_parameter<Iteration::Factory> max_iters(size_type count);


/**
* Creates the precursor to an MinimumIteration stopping criterion factory, to
* be used in conjunction with `.with_criteria(...)` function calls when
* building a solver factory. This stopping criterion wraps another stopping
* criterion inside, which only starts getting checked after the first `count`
* iterations finished.
*
* Full usage example: Stop when the relative residual
* norm is below $10^{-10}$, but with at least 100 iterations.
* ```cpp
* auto factory = gko::solver::Cg<double>::build()
* .with_criteria(gko::stop::min_iters(100,
* gko::stop::relative_residual_norm(1e-10)))
* .on(exec);
* ```
*
* @param count the number of iterations after which to start checking the
* inner criterion
Comment on lines +102 to +103
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* @param count the number of iterations after which to start checking the
* inner criterion
* @param count the number of iterations after which to start checking the
* inner criterion

nit

* @param criterion the inner criterion, which will not be checked until
* `count` iterations finished, afterwards the min_iters
* stopping criterion behaves like the inner criterion.
* @return a deferred_factory_parameter that can be passed to the
Comment on lines +106 to +107
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* stopping criterion behaves like the inner criterion.
* @return a deferred_factory_parameter that can be passed to the
* stopping criterion behaves like the inner criterion.
*
* @return a deferred_factory_parameter that can be passed to the

nit

* `with_criteria` function when building a solver.
*/
deferred_factory_parameter<CriterionFactory> min_iters(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some documentation would be needed here, so that users know how to use this wrapper.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any idea how this can be realized in the file config? From the user side this:

criteria:
  max_iters: 100
  relative_residual_norm: 1e-6
  min_iters: 10

should be pretty clear, i.e. the first two only activate once min_iters is hit, even though the list of criteria is usually an OR list.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good from the first sight. but feel weird trying explaining it in consistent way

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

require_iters sounds stronger but it does not fit the class name

size_type count, deferred_factory_parameter<CriterionFactory> criterion);


/**
* @copydoc min_iters(size_type, deferred_factory_parameter<CriterionFactory>)
* This version supports supplying multiple stopping criteria independently, all
* of which will only be checked after the minimum iteration count has been
* exceeded.
*/
template <typename... Args>
std::enable_if_t<sizeof...(Args) >= 2,
deferred_factory_parameter<CriterionFactory>>
min_iters(size_type count, Args&&... criteria)
{
std::vector<deferred_factory_parameter<CriterionFactory>> criterion_vec{
std::forward<Args>(criteria)...};
return min_iters(count, Combined::build().with_criteria(criterion_vec));
};


} // namespace stop
} // namespace gko

Expand Down
60 changes: 60 additions & 0 deletions reference/test/stop/residual_norm_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <gtest/gtest.h>

#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/stop/iteration.hpp>
#include <ginkgo/core/stop/residual_norm.hpp>

#include "core/test/utils.hpp"
Expand Down Expand Up @@ -510,6 +511,65 @@ TYPED_TEST(ResidualNorm, WaitsTillResidualGoalMultipleRHS)
}


TYPED_TEST(ResidualNorm, WorksWithMinIterationCount)
{
using Mtx = typename TestFixture::Mtx;
using NormVector = typename TestFixture::NormVector;
using T_nc = gko::remove_complex<TypeParam>;
auto initial_res = gko::initialize<Mtx>({100.0}, this->exec_);
std::shared_ptr<gko::LinOp> rhs = gko::initialize<Mtx>({10.0}, this->exec_);
auto min_factory = gko::stop::min_iters(
10, gko::stop::ResidualNorm<TypeParam>::build()
.with_baseline(gko::stop::mode::absolute)
.with_reduction_factor(r<TypeParam>::value))
.on(this->exec_);
auto min_criterion =
min_factory->generate(nullptr, rhs, nullptr, initial_res.get());
{
auto res_norm = gko::initialize<NormVector>({100.0}, this->exec_);
constexpr gko::uint8 RelativeStoppingId{1};
bool one_changed{};
gko::array<gko::stopping_status> stop_status(this->exec_, 1);
stop_status.get_data()[0].reset();

ASSERT_FALSE(
min_criterion->update()
.num_iterations(9)
.residual_norm(res_norm)
.check(RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), false);
ASSERT_EQ(one_changed, false);

res_norm->at(0) = r<TypeParam>::value * 0.9;
ASSERT_FALSE(
min_criterion->update()
.num_iterations(9)
.residual_norm(res_norm)
.check(RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), false);
ASSERT_EQ(one_changed, false);

res_norm->at(0) = r<TypeParam>::value * 1.1;
ASSERT_FALSE(
min_criterion->update()
.num_iterations(10)
.residual_norm(res_norm)
.check(RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), false);
ASSERT_EQ(one_changed, false);

res_norm->at(0) = r<TypeParam>::value * 0.9;
ASSERT_TRUE(
min_criterion->update()
.num_iterations(10)
.residual_norm(res_norm)
.check(RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), true);
ASSERT_EQ(one_changed, true);
}
}


TYPED_TEST(ResidualNorm, SimplifiedInterface)
{
using Mtx = typename TestFixture::Mtx;
Expand Down
Loading