Skip to content

Commit e45cab5

Browse files
committed
add proposed simplified residual norm interface
1 parent 7348366 commit e45cab5

File tree

8 files changed

+130
-22
lines changed

8 files changed

+130
-22
lines changed

core/stop/iteration.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include "ginkgo/core/stop/iteration.hpp"
66

7+
#include "ginkgo/core/base/abstract_factory.hpp"
8+
79

810
namespace gko {
911
namespace stop {
@@ -22,5 +24,11 @@ bool Iteration::check_impl(uint8 stoppingId, bool setFinalized,
2224
}
2325

2426

27+
deferred_factory_parameter<Iteration::Factory> iteration(size_type count)
28+
{
29+
return Iteration::build().with_max_iters(count);
30+
}
31+
32+
2533
} // namespace stop
2634
} // namespace gko

core/stop/residual_norm.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -298,12 +298,13 @@ class ResidualNormFactory
298298
"stopping criterion threshold is zero or negative when "
299299
"cast to ValueType");
300300
}
301-
return ResidualNorm<value_type>::build()
302-
.with_baseline(this->parameters_.baseline)
303-
.with_reduction_factor(this->parameters_.threshold)
304-
.on(exec)
305-
->generate(cast_args);
301+
result = ResidualNorm<value_type>::build()
302+
.with_baseline(this->parameters_.baseline)
303+
.with_reduction_factor(this->parameters_.threshold)
304+
.on(exec)
305+
->generate(cast_args);
306306
});
307+
return result;
307308
}
308309

309310
residual_norm_factory_parameters parameters_;
@@ -400,12 +401,13 @@ class ImplicitResidualNormFactory
400401
"stopping criterion threshold is zero or negative when "
401402
"cast to ValueType");
402403
}
403-
return ImplicitResidualNorm<value_type>::build()
404-
.with_baseline(this->parameters_.baseline)
405-
.with_reduction_factor(this->parameters_.threshold)
406-
.on(exec)
407-
->generate(cast_args);
404+
result = ImplicitResidualNorm<value_type>::build()
405+
.with_baseline(this->parameters_.baseline)
406+
.with_reduction_factor(this->parameters_.threshold)
407+
.on(exec)
408+
->generate(cast_args);
408409
});
410+
return result;
409411
}
410412

411413
implicit_residual_norm_factory_parameters parameters_;

core/stop/time.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include "ginkgo/core/stop/time.hpp"
66

7+
#include <ginkgo/core/base/abstract_factory.hpp>
8+
79

810
namespace gko {
911
namespace stop {
@@ -22,5 +24,19 @@ bool Time::check_impl(uint8 stoppingId, bool setFinalized,
2224
}
2325

2426

27+
deferred_factory_parameter<Time::Factory> time_sec(double time)
28+
{
29+
return Time::build().with_time_limit(
30+
std::chrono::nanoseconds{static_cast<long>(time * 1e9)});
31+
}
32+
33+
34+
deferred_factory_parameter<Time::Factory> time_ms(double time)
35+
{
36+
return Time::build().with_time_limit(
37+
std::chrono::nanoseconds{static_cast<long>(time * 1e6)});
38+
}
39+
40+
2541
} // namespace stop
2642
} // namespace gko

include/ginkgo/core/stop/iteration.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <ginkgo/core/stop/criterion.hpp>
1010

11+
#include "ginkgo/core/base/abstract_factory.hpp"
12+
1113

1214
namespace gko {
1315
namespace stop {
@@ -58,6 +60,9 @@ class Iteration : public EnablePolymorphicObject<Iteration, Criterion> {
5860
};
5961

6062

63+
deferred_factory_parameter<Iteration::Factory> iteration(size_type count);
64+
65+
6166
} // namespace stop
6267
} // namespace gko
6368

include/ginkgo/core/stop/time.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#include <ginkgo/core/stop/criterion.hpp>
1212

13+
#include "ginkgo/core/base/abstract_factory.hpp"
14+
1315

1416
namespace gko {
1517
namespace stop {
@@ -67,6 +69,10 @@ class Time : public EnablePolymorphicObject<Time, Criterion> {
6769
};
6870

6971

72+
deferred_factory_parameter<Time::Factory> time_sec(double time);
73+
deferred_factory_parameter<Time::Factory> time_ms(double time);
74+
75+
7076
} // namespace stop
7177
} // namespace gko
7278

reference/test/stop/residual_norm_kernels.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,41 @@ TYPED_TEST(ResidualNorm, WaitsTillResidualGoalMultipleRHS)
510510
}
511511

512512

513+
TYPED_TEST(ResidualNorm, SimplifiedInterface)
514+
{
515+
using Mtx = typename TestFixture::Mtx;
516+
using NormVector = typename TestFixture::NormVector;
517+
auto initial_res = gko::initialize<Mtx>({100.0}, this->exec_);
518+
auto initial_guess = gko::initialize<Mtx>({1000.0}, this->exec_);
519+
std::shared_ptr<gko::LinOp> rhs = gko::initialize<Mtx>({10.0}, this->exec_);
520+
521+
auto factory_abs = gko::stop::abs_residual_norm(0.5).on(this->exec_);
522+
auto factory_rel = gko::stop::rel_residual_norm(0.5).on(this->exec_);
523+
auto factory_red = gko::stop::residual_norm_reduction(0.5).on(this->exec_);
524+
525+
auto crit_abs =
526+
gko::as<gko::stop::ResidualNorm<TypeParam>>(factory_abs->generate(
527+
nullptr, rhs, initial_guess.get(), initial_res.get()));
528+
auto crit_rel =
529+
gko::as<gko::stop::ResidualNorm<TypeParam>>(factory_rel->generate(
530+
nullptr, rhs, initial_guess.get(), initial_res.get()));
531+
auto crit_red =
532+
gko::as<gko::stop::ResidualNorm<TypeParam>>(factory_red->generate(
533+
nullptr, rhs, initial_guess.get(), initial_res.get()));
534+
535+
ASSERT_EQ(crit_abs->get_parameters().baseline, gko::stop::mode::absolute);
536+
ASSERT_EQ(crit_rel->get_parameters().baseline, gko::stop::mode::rhs_norm);
537+
ASSERT_EQ(crit_red->get_parameters().baseline,
538+
gko::stop::mode::initial_resnorm);
539+
ASSERT_EQ(crit_abs->get_parameters().reduction_factor,
540+
gko::remove_complex<TypeParam>{0.5});
541+
ASSERT_EQ(crit_rel->get_parameters().reduction_factor,
542+
gko::remove_complex<TypeParam>{0.5});
543+
ASSERT_EQ(crit_red->get_parameters().reduction_factor,
544+
gko::remove_complex<TypeParam>{0.5});
545+
}
546+
547+
513548
template <typename T>
514549
class ResidualNormWithInitialResnorm : public ::testing::Test {
515550
protected:
@@ -964,6 +999,44 @@ TYPED_TEST(ImplicitResidualNorm, WaitsTillResidualGoalMultipleRHS)
964999
}
9651000

9661001

1002+
TYPED_TEST(ImplicitResidualNorm, SimplifiedInterface)
1003+
{
1004+
using Mtx = typename TestFixture::Mtx;
1005+
using NormVector = typename TestFixture::NormVector;
1006+
auto initial_res = gko::initialize<Mtx>({100.0}, this->exec_);
1007+
auto initial_guess = gko::initialize<Mtx>({1000.0}, this->exec_);
1008+
std::shared_ptr<gko::LinOp> rhs = gko::initialize<Mtx>({10.0}, this->exec_);
1009+
1010+
auto factory_abs =
1011+
gko::stop::implicit_abs_residual_norm(0.5).on(this->exec_);
1012+
auto factory_rel =
1013+
gko::stop::implicit_rel_residual_norm(0.5).on(this->exec_);
1014+
auto factory_red =
1015+
gko::stop::implicit_residual_norm_reduction(0.5).on(this->exec_);
1016+
1017+
auto crit_abs = gko::as<gko::stop::ImplicitResidualNorm<TypeParam>>(
1018+
factory_abs->generate(nullptr, rhs, initial_guess.get(),
1019+
initial_res.get()));
1020+
auto crit_rel = gko::as<gko::stop::ImplicitResidualNorm<TypeParam>>(
1021+
factory_rel->generate(nullptr, rhs, initial_guess.get(),
1022+
initial_res.get()));
1023+
auto crit_red = gko::as<gko::stop::ImplicitResidualNorm<TypeParam>>(
1024+
factory_red->generate(nullptr, rhs, initial_guess.get(),
1025+
initial_res.get()));
1026+
1027+
ASSERT_EQ(crit_abs->get_parameters().baseline, gko::stop::mode::absolute);
1028+
ASSERT_EQ(crit_rel->get_parameters().baseline, gko::stop::mode::rhs_norm);
1029+
ASSERT_EQ(crit_red->get_parameters().baseline,
1030+
gko::stop::mode::initial_resnorm);
1031+
ASSERT_EQ(crit_abs->get_parameters().reduction_factor,
1032+
gko::remove_complex<TypeParam>{0.5});
1033+
ASSERT_EQ(crit_rel->get_parameters().reduction_factor,
1034+
gko::remove_complex<TypeParam>{0.5});
1035+
ASSERT_EQ(crit_red->get_parameters().reduction_factor,
1036+
gko::remove_complex<TypeParam>{0.5});
1037+
}
1038+
1039+
9671040
template <typename T>
9681041
class ResidualNormWithAbsolute : public ::testing::Test {
9691042
protected:

test/mpi/solver/solver.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,8 @@ struct SimpleSolverTest {
7979
std::shared_ptr<const gko::Executor> exec)
8080
{
8181
return solver_type::build().with_criteria(
82-
gko::stop::Iteration::build().with_max_iters(iteration_count()),
83-
gko::stop::ResidualNorm<value_type>::build()
84-
.with_baseline(gko::stop::mode::absolute)
85-
.with_reduction_factor(reduction_factor())
86-
.on(exec));
82+
gko::stop::iteration(iteration_count()),
83+
gko::stop::abs_residual_norm(reduction_factor()).on(exec));
8784
}
8885

8986
static void assert_empty_state(const solver_type* mtx)

test/solver/solver.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,14 @@ struct SimpleSolverTest {
7272
std::shared_ptr<const gko::Executor> exec,
7373
gko::size_type iteration_count, bool check_residual = true)
7474
{
75-
return solver_type::build().with_criteria(
76-
gko::stop::Iteration::build().with_max_iters(iteration_count),
77-
check_residual ? gko::stop::ResidualNorm<value_type>::build()
78-
.with_baseline(gko::stop::mode::absolute)
79-
.with_reduction_factor(1e-30)
80-
.on(exec)
81-
: nullptr);
75+
if (check_residual) {
76+
return solver_type::build().with_criteria(
77+
gko::stop::iteration(iteration_count));
78+
} else {
79+
return solver_type::build().with_criteria(
80+
gko::stop::iteration(iteration_count),
81+
gko::stop::abs_residual_norm(1e-30));
82+
}
8283
}
8384

8485
static typename solver_type::parameters_type build_preconditioned(

0 commit comments

Comments
 (0)