diff --git a/adelie/src/include/adelie_core/solver/solver_base.hpp b/adelie/src/include/adelie_core/solver/solver_base.hpp index 1fb3bc90..6ba11dc6 100644 --- a/adelie/src/include/adelie_core/solver/solver_base.hpp +++ b/adelie/src/include/adelie_core/solver/solver_base.hpp @@ -24,6 +24,7 @@ inline void update_abs_grad( ) { using state_t = std::decay_t; + using value_t = typename state_t::value_t; using vec_value_t = typename state_t::vec_value_t; using rowmat_uint64_t = util::rowmat_type; @@ -94,7 +95,12 @@ inline void update_abs_grad( try_failed = true; } }; - util::omp_parallel_for(routine, 0, groups.size(), n_threads); + const bool is_not_all_none = util::rowvec_type::NullaryExpr( + constraints.size(), + [&](auto i) { return constraints[i] != nullptr; } + ).any(); + const size_t n_bytes = sizeof(value_t) * abs_grad.size(); + util::omp_parallel_for(routine, 0, groups.size(), n_threads * (is_not_all_none || (n_bytes > Configs::min_bytes))); if (try_failed) { throw util::adelie_core_solver_error( "exception raised in constraint->solve_zero(). " @@ -157,6 +163,8 @@ inline auto sparsify_dual( VecValueType& values ) { + using index_t = typename StateType::index_t; + using value_t = typename StateType::value_t; using vec_index_t = typename StateType::vec_index_t; using vec_value_t = typename StateType::vec_value_t; using sp_vec_value_t = typename StateType::sp_vec_value_t; @@ -192,7 +200,12 @@ inline auto sparsify_dual( constraint->dual(indices_v, values_v); indices_v += dual_groups[i]; }; - util::omp_parallel_for(routine, 0, n_constraints, n_threads); + const bool is_not_all_none = util::rowvec_type::NullaryExpr( + constraints.size(), + [&](auto i) { return constraints[i] != nullptr; } + ).any(); + const size_t n_bytes = (sizeof(index_t) + sizeof(value_t)) * indices.size(); + util::omp_parallel_for(routine, 0, n_constraints, n_threads * (is_not_all_none || (n_bytes > Configs::min_bytes))); } const auto last_constraint = constraints[n_constraints-1];