Changes to the linear_regression CPU train kernel (file path not shown in the captured diff view):
@@ -23,6 +23,7 @@
#include "oneapi/dal/backend/interop/table_conversion.hpp"

#include "oneapi/dal/backend/primitives/ndarray.hpp"
#include "oneapi/dal/backend/primitives/utils.hpp"

#include "oneapi/dal/table/row_accessor.hpp"

@@ -31,6 +32,8 @@
#include "oneapi/dal/algo/linear_regression/backend/model_impl.hpp"
#include "oneapi/dal/algo/linear_regression/backend/cpu/train_kernel.hpp"
#include "oneapi/dal/algo/linear_regression/backend/cpu/train_kernel_common.hpp"
#include "oneapi/dal/algo/linear_regression/backend/cpu/partial_train_kernel.hpp"
#include "oneapi/dal/algo/linear_regression/backend/cpu/finalize_train_kernel.hpp"

namespace oneapi::dal::linear_regression::backend {

@@ -54,6 +57,106 @@ using batch_lr_kernel_t = daal_lr::training::internal::BatchKernel<Float, daal_l
template <typename Float, daal::CpuType Cpu>
using batch_rr_kernel_t = daal_rr::training::internal::BatchKernel<Float, daal_rr_method, Cpu>;

template <typename Float, typename Task>
static train_result<Task> call_daal_spmd_kernel(const context_cpu& ctx,
const detail::descriptor_base<Task>& desc,
const detail::train_parameters<Task>& params,
const table& data,
const table& resp) {
auto& comm = ctx.get_communicator();

/// Compute partial X^T * X and X^T * y on each rank
partial_train_input<Task> partial_input(data, resp);
auto partial_result =
dal::linear_regression::backend::partial_train_kernel_cpu<Float, method::norm_eq, Task>{}(
ctx,
desc,
params,
partial_input);

/// Get local partial X^T * X and X^T * y as array<Float> to pass to collective allgatherv
const auto& xtx_local = partial_result.get_partial_xtx();
const auto& xty_local = partial_result.get_partial_xty();
const auto xtx_local_nd = pr::table2ndarray<Float>(xtx_local);
const auto xty_local_nd = pr::table2ndarray<Float>(xty_local);
const auto xtx_local_ary =
dal::array<Float>::wrap(xtx_local_nd.get_data(), xtx_local_nd.get_count());
const auto xty_local_ary =
dal::array<Float>::wrap(xty_local_nd.get_data(), xty_local_nd.get_count());

/// Allocate storage for gathered X^T * X and X^T * y across all ranks
auto rank_count = comm.get_rank_count();
const std::int64_t ext_feature_count = xtx_local.get_row_count();
const std::int64_t response_count = xty_local.get_row_count();
auto xtx_gathered_ary =
dal::array<Float>::empty(ext_feature_count * ext_feature_count * rank_count);
auto xty_gathered_ary =
dal::array<Float>::empty(response_count * ext_feature_count * rank_count);
/// Received counts of elements in X^T * X and X^T * y for each rank
std::vector<std::int64_t> xtx_recv_counts_ary(rank_count,
ext_feature_count * ext_feature_count);
std::vector<std::int64_t> xty_recv_counts_ary(rank_count, response_count * ext_feature_count);
/// Displacements of X^T * X and X^T * y in the gathered arrays for each rank
/// Note: All ranks have the same size of X^T * X and X^T * y
std::vector<std::int64_t> xtx_displs_ary(rank_count);
std::vector<std::int64_t> xty_displs_ary(rank_count);
for (std::int64_t i = 0; i < rank_count; i++) {
xtx_displs_ary[i] = i * ext_feature_count * ext_feature_count;
xty_displs_ary[i] = i * response_count * ext_feature_count;
}

/// Collectively gather X^T * X and X^T * y across all ranks
comm.allgatherv(xtx_local_ary,
xtx_gathered_ary,
xtx_recv_counts_ary.data(),
xtx_displs_ary.data());
comm.allgatherv(xty_local_ary,
xty_gathered_ary,
xty_recv_counts_ary.data(),
xty_displs_ary.data());

/// Sum up the gathered X^T * X and X^T * y across all ranks
/// Note: DAAL has a kernel for this step:
/// daal::algorithms::linear_regression::training::internal::DistributedKernel
/// But the logic in that kernel is very simple, so it is more efficient to implement
/// it right here than to convert the inputs and outputs and call the DAAL kernel.
auto xtx_ary = dal::array<Float>::zeros(ext_feature_count * ext_feature_count);
auto xty_ary = dal::array<Float>::zeros(ext_feature_count * response_count);
const Float* xtx_gathered = xtx_gathered_ary.get_data();
const Float* xty_gathered = xty_gathered_ary.get_data();
Float* xtx = xtx_ary.get_mutable_data();
Float* xty = xty_ary.get_mutable_data();

for (std::int64_t r = 0; r < rank_count; ++r) {
const Float* xtx_gathered_r = xtx_gathered + xtx_displs_ary[r];
for (std::int64_t i = 0; i < xtx_recv_counts_ary[r]; ++i) {
xtx[i] += xtx_gathered_r[i];
}
const Float* xty_gathered_r = xty_gathered + xty_displs_ary[r];
for (std::int64_t i = 0; i < xty_recv_counts_ary[r]; ++i) {
xty[i] += xty_gathered_r[i];
}
}

/// Wrap the summed X^T * X and X^T * y into homogen tables
auto xtx_table = homogen_table::wrap(xtx_ary, ext_feature_count, ext_feature_count);
auto xty_table = homogen_table::wrap(xty_ary, response_count, ext_feature_count);

/// Compute regression coefficients
partial_train_result<Task> partial_result_final;
partial_result_final.set_partial_xtx(xtx_table);
partial_result_final.set_partial_xty(xty_table);
auto result =
dal::linear_regression::backend::finalize_train_kernel_cpu<Float, method::norm_eq, Task>{}(
ctx,
desc,
params,
partial_result_final);

return result;
}
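
The partial/finalize split in call_daal_spmd_kernel works because the normal-equations cross-products are additive over a row-wise partition of the data. A sketch of the underlying algebra (standard least squares; the ridge kernel adds a penalty term), not taken verbatim from the kernel:

% Rows of X and y are partitioned across ranks r = 1, ..., R:
X^{\top}X = \sum_{r=1}^{R} X_r^{\top} X_r, \qquad
X^{\top}y = \sum_{r=1}^{R} X_r^{\top} y_r, \qquad
\hat{\beta} = \bigl(X^{\top}X + \lambda I\bigr)^{-1} X^{\top}y
% \lambda = 0 for ordinary least squares (daal_lr); \lambda > 0 for the ridge kernel (daal_rr).

So each rank only needs its local cross-products, and the finalize step solves a single small system whose size depends on ext_feature_count and response_count, not on the number of rows.
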

template <typename Float, typename Task>
static train_result<Task> call_daal_kernel(const context_cpu& ctx,
const detail::descriptor_base<Task>& desc,
@@ -171,6 +274,13 @@ static train_result<Task> train(const context_cpu& ctx,
const detail::descriptor_base<Task>& desc,
const detail::train_parameters<Task>& params,
const train_input<Task>& input) {
if (ctx.get_communicator().get_rank_count() > 1) {
return call_daal_spmd_kernel<Float, Task>(ctx,
desc,
params,
input.get_data(),
input.get_responses());
}
return call_daal_kernel<Float, Task>(ctx,
desc,
params,
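
As a side note on the bookkeeping in call_daal_spmd_kernel: with equal block sizes per rank, the allgatherv displacements are just multiples of the block size, and the post-gather reduction is a plain elementwise sum. A minimal serial sketch of that step in plain C++ (it does not use the oneDAL communicator API):

#include <cstdint>
#include <iostream>
#include <vector>

// Serial model of the allgatherv + elementwise-sum reduction: every rank
// contributes a block of block_size elements, gathered back to back, and the
// displacement of rank r is simply r * block_size (as in xtx_displs_ary above).
std::vector<double> reduce_gathered(const std::vector<double>& gathered,
                                    std::int64_t rank_count,
                                    std::int64_t block_size) {
    std::vector<double> reduced(block_size, 0.0);
    for (std::int64_t r = 0; r < rank_count; ++r) {
        const std::int64_t displ = r * block_size;
        for (std::int64_t i = 0; i < block_size; ++i) {
            reduced[i] += gathered[displ + i];
        }
    }
    return reduced;
}

int main() {
    // Two "ranks", each holding a 2x2 partial X^T * X flattened row-major.
    const std::vector<double> gathered = { 1, 2, 2, 5, /* rank 0 */ 3, 1, 1, 4 /* rank 1 */ };
    const auto xtx = reduce_gathered(gathered, /*rank_count=*/2, /*block_size=*/4);
    for (double v : xtx)
        std::cout << v << ' '; // prints: 4 3 3 9
    std::cout << '\n';
}

With rank_count = comm.get_rank_count() and block_size = ext_feature_count * ext_feature_count this mirrors the xtx reduction; the xty reduction is identical with block_size = response_count * ext_feature_count.
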
4 changes: 2 additions & 2 deletions cpp/oneapi/dal/algo/linear_regression/detail/train_ops.cpp
@@ -37,7 +37,7 @@ struct train_ops_dispatcher<Policy, Float, Method, Task> {
const descriptor_base<Task>& desc,
const train_input<Task>& input) const {
using kernel_dispatcher_t = dal::backend::kernel_dispatcher< //
KERNEL_SINGLE_NODE_CPU(parameters::train_parameters_cpu<Float, Method, Task>)>;
KERNEL_UNIVERSAL_SPMD_CPU(parameters::train_parameters_cpu<Float, Method, Task>)>;
return kernel_dispatcher_t{}(ctx, desc, input);
}

@@ -54,7 +54,7 @@ struct train_ops_dispatcher<Policy, Float, Method, Task> {
const train_parameters<Task>& params,
const train_input<Task>& input) const {
using kernel_dispatcher_t = dal::backend::kernel_dispatcher< //
KERNEL_SINGLE_NODE_CPU(backend::train_kernel_cpu<Float, Method, Task>)>;
KERNEL_UNIVERSAL_SPMD_CPU(backend::train_kernel_cpu<Float, Method, Task>)>;
return kernel_dispatcher_t{}(ctx, desc, params, input);
}
};
Changes to the GPU-enabled train_ops dispatcher (file path not shown in the captured diff view):
@@ -37,7 +37,7 @@ struct train_ops_dispatcher<Policy, Float, Method, Task> {
const descriptor_base<Task>& desc,
const train_input<Task>& input) const {
using kernel_dispatcher_t = dal::backend::kernel_dispatcher<
KERNEL_SINGLE_NODE_CPU(parameters::train_parameters_cpu<Float, Method, Task>),
KERNEL_UNIVERSAL_SPMD_CPU(parameters::train_parameters_cpu<Float, Method, Task>),
KERNEL_UNIVERSAL_SPMD_GPU(parameters::train_parameters_gpu<Float, Method, Task>)>;
return kernel_dispatcher_t{}(ctx, desc, input);
}
@@ -55,7 +55,7 @@ struct train_ops_dispatcher<Policy, Float, Method, Task> {
const train_parameters<Task>& params,
const train_input<Task>& input) const {
using kernel_dispatcher_t = dal::backend::kernel_dispatcher<
KERNEL_SINGLE_NODE_CPU(backend::train_kernel_cpu<Float, Method, Task>),
KERNEL_UNIVERSAL_SPMD_CPU(backend::train_kernel_cpu<Float, Method, Task>),
KERNEL_UNIVERSAL_SPMD_GPU(backend::train_kernel_gpu<Float, Method, Task>)>;
return kernel_dispatcher_t{}(ctx, desc, params, input);
}
1 change: 0 additions & 1 deletion cpp/oneapi/dal/algo/linear_regression/test/spmd.cpp
@@ -19,7 +19,6 @@
namespace oneapi::dal::linear_regression::test {

TEMPLATE_LIST_TEST_M(lr_spmd_test, "LR common flow", "[lr][spmd]", lr_types) {
SKIP_IF(this->get_policy().is_cpu());
SKIP_IF(this->not_float64_friendly());

this->generate(777);
78 changes: 78 additions & 0 deletions cpp/oneapi/dal/backend/dispatcher.hpp
@@ -31,6 +31,9 @@
#define KERNEL_SINGLE_NODE_CPU(...) \
KERNEL_SPEC(::oneapi::dal::backend::single_node_cpu_kernel, __VA_ARGS__)

#define KERNEL_UNIVERSAL_SPMD_CPU(...) \
KERNEL_SPEC(::oneapi::dal::backend::universal_spmd_cpu_kernel, __VA_ARGS__)

#define KERNEL_SINGLE_NODE_GPU(...) \
KERNEL_SPEC(::oneapi::dal::backend::single_node_gpu_kernel, __VA_ARGS__)

@@ -152,6 +155,9 @@ inline auto dispatch_by_device(const detail::data_parallel_policy& policy,
/// Tag that indicates CPU kernel for single-node
struct single_node_cpu_kernel {};

/// Tag that indicates universal CPU kernel for single-node and SPMD modes
struct universal_spmd_cpu_kernel {};

/// Tag that indicates GPU kernel for single-node
struct single_node_gpu_kernel {};

@@ -223,6 +229,48 @@ struct kernel_dispatcher<kernel_spec<single_node_cpu_kernel, CpuKernel>> {
#endif
};

/// Dispatcher for the case of multi-node CPU algorithm based on universal SPMD kernel
template <typename CpuKernel>
struct kernel_dispatcher<kernel_spec<universal_spmd_cpu_kernel, CpuKernel>> {
template <typename... Args>
auto operator()(const detail::host_policy& policy, Args&&... args) const {
return CpuKernel{}(context_cpu{}, std::forward<Args>(args)...);
}

template <typename... Args>
auto operator()(const detail::spmd_host_policy& policy, Args&&... args) const {
return CpuKernel{}(context_cpu{ policy }, std::forward<Args>(args)...);
}

#ifdef ONEDAL_DATA_PARALLEL
template <typename... Args>
auto operator()(const detail::data_parallel_policy& policy, Args&&... args) const {
return dispatch_by_device(
policy,
[&]() {
return CpuKernel{}(context_cpu{}, std::forward<Args>(args)...);
},
[&]() -> cpu_kernel_return_t<CpuKernel, Args...> {
// We have to specify the return type for this lambda because the compiler cannot
// infer it from a body that consists of a single `throw` expression
using msg = detail::error_messages;
throw unimplemented{ msg::algorithm_is_not_implemented_for_this_device() };
});
}
#endif

#ifdef ONEDAL_DATA_PARALLEL
template <typename... Args>
auto operator()(const detail::spmd_data_parallel_policy& policy, Args&&... args) const
-> cpu_kernel_return_t<CpuKernel, Args...> {
// We have to specify the return type for this function because the compiler cannot
// infer it from a body that consists of a single `throw` expression
using msg = detail::error_messages;
throw unimplemented{ msg::spmd_version_of_algorithm_is_not_implemented() };
}
#endif
};

#ifdef ONEDAL_DATA_PARALLEL
/// Dispatcher for the case of single-node CPU and GPU algorithm
template <typename CpuKernel, typename GpuKernel>
@@ -286,6 +334,36 @@ struct kernel_dispatcher<kernel_spec<single_node_cpu_kernel, CpuKernel>,
});
}
};

/// Dispatcher for the case of multi-node CPU algorithm based on universal SPMD kernel and
/// multi-node GPU algorithm based on universal SPMD kernel
template <typename CpuKernel, typename GpuKernel>
struct kernel_dispatcher<kernel_spec<universal_spmd_cpu_kernel, CpuKernel>,
kernel_spec<universal_spmd_gpu_kernel, GpuKernel>> {
template <typename... Args>
auto operator()(const detail::data_parallel_policy& policy, Args&&... args) const {
return dispatch_by_device(
policy,
[&]() {
return CpuKernel{}(context_cpu{ policy }, std::forward<Args>(args)...);
},
[&]() {
return GpuKernel{}(context_gpu{ policy }, std::forward<Args>(args)...);
});
}

template <typename... Args>
auto operator()(const detail::spmd_data_parallel_policy& policy, Args&&... args) const {
return dispatch_by_device(
policy.get_local(),
[&]() {
return CpuKernel{}(context_cpu{ policy }, std::forward<Args>(args)...);
},
[&]() {
return GpuKernel{}(context_gpu{ policy }, std::forward<Args>(args)...);
});
}
};
#endif

inline bool test_cpu_extension(detail::cpu_extension mask, detail::cpu_extension test) {
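
The new dispatcher specializations follow the existing tag-dispatch design: an algorithm registers its kernel under a tag (here universal_spmd_cpu_kernel), and the policy-specific operator() overloads of the matching kernel_dispatcher specialization decide which execution context the kernel receives. Below is a simplified, self-contained model of that pattern; the type names are made up for illustration and are not the real oneDAL classes.

#include <iostream>
#include <utility>

// Illustrative stand-ins for the oneDAL tags and policies.
struct universal_spmd_cpu_tag {};
struct host_policy {};
struct spmd_host_policy {
    int rank_count;
};

// kernel_spec binds a dispatch tag to a kernel type; kernel_dispatcher specializes on it.
template <typename Tag, typename Kernel>
struct kernel_spec {};

template <typename... Specs>
struct kernel_dispatcher;

// One kernel registered under the "universal SPMD CPU" tag serves both policies:
// the policy overload only decides what execution context the kernel receives.
template <typename CpuKernel>
struct kernel_dispatcher<kernel_spec<universal_spmd_cpu_tag, CpuKernel>> {
    template <typename... Args>
    auto operator()(const host_policy&, Args&&... args) const {
        return CpuKernel{}(/*rank_count=*/1, std::forward<Args>(args)...);
    }
    template <typename... Args>
    auto operator()(const spmd_host_policy& policy, Args&&... args) const {
        return CpuKernel{}(policy.rank_count, std::forward<Args>(args)...);
    }
};

// A toy kernel that branches on the rank count, like train() does above.
struct toy_train_kernel {
    int operator()(int rank_count, int x) const {
        return rank_count > 1 ? x * rank_count : x; // "SPMD" path vs. single-node path
    }
};

int main() {
    kernel_dispatcher<kernel_spec<universal_spmd_cpu_tag, toy_train_kernel>> dispatch;
    std::cout << dispatch(host_policy{}, 5) << ' '
              << dispatch(spmd_host_policy{ 4 }, 5) << '\n'; // prints: 5 20
}
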