Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cpp/src/pca/pca.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ void pcaFit(const raft::handle_t& handle,
raft::matrix::seqRoot(explained_var, singular_vals, scalar, n_components, stream, true);

raft::stats::meanAdd<false, true>(input, input, mu, prms.n_cols, prms.n_rows, stream);

signFlipComponents(components, prms.n_components, prms.n_cols, stream);
}

/**
Expand Down Expand Up @@ -180,7 +182,6 @@ void pcaFitTransform(const raft::handle_t& handle,
prms,
stream);
pcaTransform(handle, input, components, trans_input, singular_vals, mu, prms, stream);
signFlip(trans_input, prms.n_rows, prms.n_components, components, prms.n_cols, stream);
}

// TODO: implement pcaGetCovariance function
Expand Down
7 changes: 3 additions & 4 deletions cpp/src/pca/pca_mg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

#include <cuml/decomposition/pca.hpp>
#include <cuml/decomposition/pca_mg.hpp>
#include <cuml/decomposition/sign_flip_mg.hpp>

#include <cumlprims/opg/linalg/qr_based_svd.hpp>
#include <cumlprims/opg/matrix/matrix_utils.hpp>
Expand Down Expand Up @@ -75,6 +74,8 @@ void fit_impl(raft::handle_t& handle,
raft::matrix::seqRoot(explained_var, singular_vals, scalar, prms.n_components, streams[0], true);

Stats::opg::mean_add(input_data, input_desc, mu_data, comm, streams, n_streams);

signFlipComponents(components, prms.n_components, prms.n_cols, streams[0]);
}

/**
Expand Down Expand Up @@ -164,7 +165,7 @@ void fit_impl(raft::handle_t& handle,
rank);

// sign flip
sign_flip(handle, uMatrixParts, input_desc, vMatrix.data(), prms.n_cols, streams, n_streams);
signFlipComponents(vMatrix.data(), prms.n_cols, prms.n_cols, stream);

// Calculate instance variables
rmm::device_uvector<T> explained_var_all(prms.n_cols, stream);
Expand Down Expand Up @@ -539,8 +540,6 @@ void fit_transform_impl(raft::handle_t& handle,
n_streams,
verbose);

sign_flip(handle, trans_data, input_desc, components, prms.n_components, streams, n_streams);

for (std::uint32_t i = 0; i < n_streams; i++) {
handle.sync_stream(streams[i]);
}
Expand Down
84 changes: 82 additions & 2 deletions cpp/src/tsvd/tsvd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,86 @@ void calEig(const raft::handle_t& handle,
raft::matrix::rowReverse(explained_var, prms.n_cols, std::size_t(1), stream);
}

namespace detail {

template <typename T>
struct AbsMaxOp {
__device__ __host__ inline T operator()(T a, T b) const {
T abs_a = a >= 0 ? a : -a;
T abs_b = b >= 0 ? b : -b;
return abs_a > abs_b ? a : b; // return signed value of the larger |x|
}
};

template <typename T>
struct FlipOp {
T* components;
T* max_vals;
std::size_t n_rows;
std::size_t n_cols;

__device__ void operator()(std::size_t row) {
if (max_vals[row] < T(0)) {
// Flip corresponding components
for (std::size_t i = row; i < n_rows * n_cols; i += n_rows) {
components[i] = -components[i];
}
}
}
};

} // namespace detail

/**
* @defgroup sign flip for PCA and tSVD. This is used to stabilize the sign of column major eigen
* vectors
* @param components: components matrix, used to determine the sign of max absolute value
* @param n_rows: number of rows of components matrix
* @param n_cols: number of columns of components matrix
* @param stream cuda stream
* @{
*/
template <typename math_t>
void signFlipComponents(math_t* components,
std::size_t n_rows,
std::size_t n_cols,
cudaStream_t stream)
{
rmm::device_uvector<math_t> max_vals(n_rows, stream);

// Step 1: find component-wise max absolute values
raft::linalg::reduce<
true, // rowMajor
false, // alongRows
math_t, // InType
math_t, // OutType
std::size_t, // IdxType
raft::identity_op, // MainLambda
detail::AbsMaxOp<math_t>, // ReduceLambda
raft::identity_op // FinalLambda
>(
max_vals.data(), // OutType* out
components, // InType const* in
n_rows, // rows
n_cols, // cols
math_t(0), // init value
stream, // cudaStream_t
true, // inclusive (can be true)
raft::identity_op(), // main_op
detail::AbsMaxOp<math_t>(), // reduce_op
raft::identity_op() // final_op
);

// Step 2: flip rows where needed
detail::FlipOp<math_t> op{components, max_vals.data(), n_rows, n_cols};
thrust::for_each(
rmm::exec_policy(stream),
thrust::make_counting_iterator<std::size_t>(0),
thrust::make_counting_iterator<std::size_t>(n_rows),
op
);
}

/**
* @defgroup sign flip for PCA and tSVD. This is used to stabilize the sign of column major eigen
* vectors
Expand Down Expand Up @@ -235,6 +315,8 @@ void tsvdFit(const raft::handle_t& handle,

math_t scalar = math_t(1);
raft::matrix::seqRoot(explained_var_all.data(), singular_vals, scalar, n_components, stream);

signFlipComponents(components, prms.n_components, prms.n_cols, stream);
}

/**
Expand Down Expand Up @@ -267,8 +349,6 @@ void tsvdFitTransform(const raft::handle_t& handle,
tsvdFit(handle, input, components, singular_vals, prms, stream);
tsvdTransform(handle, input, components, trans_input, prms, stream);

signFlip(trans_input, prms.n_rows, prms.n_components, components, prms.n_cols, stream);

rmm::device_uvector<math_t> mu_trans(prms.n_components, stream);
raft::stats::mean<false>(
mu_trans.data(), trans_input, prms.n_components, prms.n_rows, false, stream);
Expand Down
209 changes: 209 additions & 0 deletions python/cuml/tests/test_pca_sign_flip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import numpy as np
import pandas as pd

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA as skPCA
from cuml.decomposition import PCA as cuPCA

def test_pca_sign_flip():
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=13)
sk_pca = skPCA(n_components=10, random_state=13)
sk_pca_res = sk_pca.fit_transform(X)
print(sk_pca_res)
# [[ 1.16014257e+03 -2.93917544e+02 4.85783976e+01 ... 1.48167035e-01
# -7.45463100e-01 5.89359489e-01]
# [ 1.26912244e+03 1.56301818e+01 -3.53945342e+01 ... 2.00803920e-01
# 4.85827948e-01 -8.40347189e-02]
# [ 9.95793889e+02 3.91567432e+01 -1.70975298e+00 ... -2.74025583e-01
# 1.73874273e-01 -1.86993508e-01]
# ...
# [ 3.14501756e+02 4.75535252e+01 -1.04424072e+01 ... -4.42278787e-01
# 9.73984733e-02 -1.44667285e-01]
# [ 1.12485812e+03 3.41292250e+01 -1.97420874e+01 ... -3.59964104e-01
# -3.85030186e-01 6.15467490e-01]
# [-7.71527622e+02 -8.86431064e+01 2.38890319e+01 ... 3.00390899e-02
# 4.23451051e-01 -3.01438975e-01]]
sk_pca_comp = sk_pca.components_
print(sk_pca_comp)
# [[ 5.08623202e-03 2.19657026e-03 3.50763298e-02 5.16826469e-01
# 4.23694535e-06 4.05260047e-05 8.19399539e-05 4.77807775e-05
# 7.07804332e-06 -2.62155251e-06 3.13742507e-04 -6.50984008e-05
# 2.23634150e-03 5.57271669e-02 -8.05646029e-07 5.51918197e-06
# 8.87094462e-06 3.27915009e-06 -1.24101836e-06 -8.54530832e-08
# 7.15473257e-03 3.06736622e-03 4.94576447e-02 8.52063392e-01
# 6.42005481e-06 1.01275937e-04 1.68928625e-04 7.36658178e-05
# 1.78986262e-05 1.61356159e-06]
# [ 9.28705650e-03 -2.88160658e-03 6.27480827e-02 8.51823720e-01
# -1.48194356e-05 -2.68862249e-06 7.51419574e-05 4.63501038e-05
# -2.52430431e-05 -1.61197148e-05 -5.38692831e-05 3.48370414e-04
# 8.19640791e-04 7.51112451e-03 1.49438131e-06 1.27357957e-05
# 2.86921009e-05 9.36007477e-06 1.22647432e-05 2.89683790e-07
# -5.68673345e-04 -1.32152605e-02 -1.85961117e-04 -5.19742358e-01
# -7.68565692e-05 -2.56104144e-04 -1.75471479e-04 -3.05051743e-05
# -1.57042845e-04 -5.53071662e-05]
# [-1.23425821e-02 -6.35497857e-03 -7.16694814e-02 -2.78944181e-02
# 7.26596827e-05 1.01754350e-04 2.65989729e-04 3.60471764e-05
# 1.41290958e-04 5.06376971e-05 6.06156709e-03 6.23377635e-03
# 4.38560369e-02 9.90245878e-01 4.34471433e-05 1.27658711e-04
# 2.07365800e-04 4.78855144e-05 1.14411270e-04 2.43158370e-05
# -1.55659935e-02 -3.15446196e-02 -9.23133791e-02 -3.93186778e-02
# -4.21307399e-05 -7.64833237e-04 -8.46552237e-04 -3.33596393e-04
# -3.49992952e-04 -4.09371692e-05]
# [ 3.42380473e-02 3.62415111e-01 3.29281417e-01 -3.94122494e-02
# 3.44153009e-04 3.00489873e-03 3.40779110e-03 1.24725032e-03
# 9.66809714e-04 1.99194796e-04 4.08618843e-03 2.26398666e-02
# 4.98565303e-02 1.01980275e-01 -6.69114619e-06 8.93263012e-04
# 9.95328878e-04 2.34560912e-04 1.24528446e-04 6.72412843e-05
# 6.18999387e-02 5.42057412e-01 6.66816451e-01 -3.87691524e-02
# 7.21927589e-04 1.03619614e-02 1.15618071e-02 2.99467373e-03
# 2.64085170e-03 9.08697292e-04]
# [-3.54561138e-02 4.43187450e-01 -3.13382893e-01 4.60378117e-02
# -5.79019359e-04 -2.52639926e-03 -2.19520726e-03 -1.13196737e-03
# -9.37014169e-04 -2.07028041e-04 -2.93386180e-03 3.75434531e-02
# -3.57275320e-02 -5.08045702e-02 5.18037664e-05 -5.24579915e-04
# -5.76839903e-04 -2.25598524e-04 -6.11321955e-05 -4.64421630e-05
# -5.31447667e-02 6.12574312e-01 -5.64102976e-01 1.84525531e-02
# -4.65062512e-04 -6.09647380e-03 -6.16530214e-03 -2.41157233e-03
# -1.88324182e-03 -5.19581269e-04]
# [ 1.31213101e-01 2.13486089e-01 8.40324225e-01 -5.23468101e-02
# 4.06502430e-04 1.01527758e-03 -2.75600070e-04 5.76346878e-04
# 1.79444495e-04 -2.19983885e-04 8.45585552e-04 1.24013980e-02
# -9.48056397e-02 2.31166662e-02 1.49989514e-05 3.59930492e-04
# 3.83840527e-04 4.25616208e-04 4.11711911e-05 1.00135535e-04
# 7.49807186e-02 -1.21167279e-01 -4.44630524e-01 2.01806772e-02
# -1.47871511e-03 -9.48569782e-03 -1.04511092e-02 -1.59681971e-03
# -5.47852368e-03 -1.23726579e-03]
# [-3.35131912e-02 7.84253475e-01 -1.89074737e-01 7.33787337e-03
# -1.60796958e-03 -2.77107786e-04 -1.02365525e-03 -9.05454729e-04
# -5.98298140e-04 4.25619565e-05 1.53826412e-02 -6.66867308e-02
# 1.48548561e-01 -2.25977534e-02 2.37177063e-04 1.27405510e-03
# 1.41036865e-03 5.21614746e-04 7.13773114e-04 1.94572501e-04
# -4.53901657e-02 -5.52024144e-01 1.17015608e-01 -1.83169390e-03
# -3.94704099e-03 -7.74390329e-03 -1.08822097e-02 -4.24156865e-03
# -7.03799527e-03 -1.17067750e-03]
# [-7.54924586e-02 -6.87405638e-02 8.39642267e-02 -3.00992471e-03
# 3.43658580e-03 1.55731486e-02 1.92512587e-02 9.07295722e-03
# 9.14981254e-03 3.00903313e-03 8.41264091e-02 5.87281487e-01
# 7.77894165e-01 -4.22340772e-02 1.60686058e-03 8.68045428e-03
# 1.17908464e-02 3.99536083e-03 5.54864404e-03 1.44055227e-03
# -1.29429315e-01 1.60158677e-02 -7.32396132e-02 5.01077221e-03
# -2.13704578e-03 -8.59643464e-03 -5.96017341e-03 4.06900451e-04
# -7.69020237e-03 -9.48038273e-05]
# [ 3.50549264e-01 -4.08376429e-03 -1.32828034e-01 3.82916116e-03
# -8.22698130e-03 -5.63148308e-02 -7.02297025e-02 -1.92498100e-02
# -1.49895864e-02 -7.63859418e-03 9.15127909e-02 -1.62432583e-02
# 1.90207008e-01 -4.47806986e-03 -1.02379224e-03 -1.89362797e-02
# -2.52133644e-02 -2.90414428e-03 -1.76381905e-03 -2.10548292e-03
# 8.60507446e-01 2.11047340e-03 -3.86007208e-02 -4.09978696e-03
# -1.14296358e-02 -1.56279540e-01 -1.91177497e-01 -3.09293509e-02
# -2.02890503e-02 -1.76591684e-02]
# [ 1.39559852e-01 7.66679112e-02 -8.92113884e-02 1.95571373e-03
# -4.44685266e-03 -2.99475404e-02 -2.79441150e-02 -1.04362500e-02
# -8.27800166e-03 -4.54280517e-03 -3.22073955e-02 7.91295671e-01
# -5.48501632e-01 2.03427164e-02 4.45191329e-04 -6.64830167e-03
# -9.88467883e-03 -2.63444827e-03 1.06193405e-03 -6.29812958e-04
# 2.83270489e-02 -9.44156308e-02 8.41866696e-02 -3.10937021e-03
# -1.17730197e-02 -8.54753911e-02 -1.00411922e-01 -2.84427473e-02
# -2.97280861e-02 -1.21927540e-02]]
cu_pca = cuPCA(n_components=10)
cu_pca_res = cu_pca.fit_transform(X).to_numpy()
print(cu_pca_res)
# [[ 1.16014257e+03 2.93917544e+02 4.85783976e+01 ... 1.48167035e-01
# 7.45463100e-01 -5.89359489e-01]
# [ 1.26912244e+03 -1.56301818e+01 -3.53945342e+01 ... 2.00803920e-01
# -4.85827948e-01 8.40347188e-02]
# [ 9.95793889e+02 -3.91567432e+01 -1.70975298e+00 ... -2.74025583e-01
# -1.73874273e-01 1.86993508e-01]
# ...
# [ 3.14501756e+02 -4.75535252e+01 -1.04424072e+01 ... -4.42278787e-01
# -9.73984733e-02 1.44667284e-01]
# [ 1.12485812e+03 -3.41292250e+01 -1.97420874e+01 ... -3.59964104e-01
# 3.85030186e-01 -6.15467490e-01]
# [-7.71527622e+02 8.86431064e+01 2.38890319e+01 ... 3.00390899e-02
# -4.23451051e-01 3.01438975e-01]]
cu_pca_comp = cu_pca.components_.to_numpy()
print(cu_pca_comp)
# [[ 5.08623202e-03 2.19657026e-03 3.50763298e-02 5.16826469e-01
# 4.23694535e-06 4.05260047e-05 8.19399539e-05 4.77807775e-05
# 7.07804332e-06 -2.62155251e-06 3.13742507e-04 -6.50984008e-05
# 2.23634150e-03 5.57271669e-02 -8.05646029e-07 5.51918197e-06
# 8.87094462e-06 3.27915009e-06 -1.24101836e-06 -8.54530832e-08
# 7.15473257e-03 3.06736622e-03 4.94576447e-02 8.52063392e-01
# 6.42005481e-06 1.01275937e-04 1.68928625e-04 7.36658178e-05
# 1.78986262e-05 1.61356159e-06]
# [-9.28705650e-03 2.88160658e-03 -6.27480827e-02 -8.51823720e-01
# 1.48194356e-05 2.68862249e-06 -7.51419574e-05 -4.63501038e-05
# 2.52430431e-05 1.61197148e-05 5.38692831e-05 -3.48370414e-04
# -8.19640791e-04 -7.51112451e-03 -1.49438131e-06 -1.27357957e-05
# -2.86921009e-05 -9.36007477e-06 -1.22647432e-05 -2.89683790e-07
# 5.68673345e-04 1.32152605e-02 1.85961117e-04 5.19742358e-01
# 7.68565692e-05 2.56104144e-04 1.75471479e-04 3.05051743e-05
# 1.57042845e-04 5.53071662e-05]
# [-1.23425821e-02 -6.35497857e-03 -7.16694814e-02 -2.78944181e-02
# 7.26596827e-05 1.01754350e-04 2.65989729e-04 3.60471764e-05
# 1.41290958e-04 5.06376971e-05 6.06156709e-03 6.23377635e-03
# 4.38560369e-02 9.90245878e-01 4.34471433e-05 1.27658711e-04
# 2.07365800e-04 4.78855144e-05 1.14411270e-04 2.43158370e-05
# -1.55659935e-02 -3.15446196e-02 -9.23133791e-02 -3.93186778e-02
# -4.21307399e-05 -7.64833237e-04 -8.46552237e-04 -3.33596393e-04
# -3.49992952e-04 -4.09371692e-05]
# [ 3.42380473e-02 3.62415111e-01 3.29281417e-01 -3.94122494e-02
# 3.44153009e-04 3.00489873e-03 3.40779110e-03 1.24725032e-03
# 9.66809714e-04 1.99194796e-04 4.08618843e-03 2.26398666e-02
# 4.98565303e-02 1.01980275e-01 -6.69114619e-06 8.93263012e-04
# 9.95328878e-04 2.34560912e-04 1.24528446e-04 6.72412843e-05
# 6.18999387e-02 5.42057412e-01 6.66816451e-01 -3.87691524e-02
# 7.21927589e-04 1.03619614e-02 1.15618071e-02 2.99467373e-03
# 2.64085170e-03 9.08697292e-04]
# [-3.54561138e-02 4.43187450e-01 -3.13382893e-01 4.60378117e-02
# -5.79019359e-04 -2.52639926e-03 -2.19520726e-03 -1.13196737e-03
# -9.37014169e-04 -2.07028041e-04 -2.93386180e-03 3.75434531e-02
# -3.57275320e-02 -5.08045702e-02 5.18037664e-05 -5.24579915e-04
# -5.76839903e-04 -2.25598524e-04 -6.11321955e-05 -4.64421630e-05
# -5.31447667e-02 6.12574312e-01 -5.64102976e-01 1.84525531e-02
# -4.65062512e-04 -6.09647380e-03 -6.16530214e-03 -2.41157233e-03
# -1.88324182e-03 -5.19581269e-04]
# [ 1.31213101e-01 2.13486089e-01 8.40324225e-01 -5.23468101e-02
# 4.06502430e-04 1.01527758e-03 -2.75600070e-04 5.76346878e-04
# 1.79444495e-04 -2.19983885e-04 8.45585552e-04 1.24013980e-02
# -9.48056397e-02 2.31166662e-02 1.49989514e-05 3.59930492e-04
# 3.83840527e-04 4.25616208e-04 4.11711911e-05 1.00135535e-04
# 7.49807186e-02 -1.21167279e-01 -4.44630524e-01 2.01806772e-02
# -1.47871511e-03 -9.48569782e-03 -1.04511092e-02 -1.59681971e-03
# -5.47852368e-03 -1.23726579e-03]
# [-3.35131912e-02 7.84253475e-01 -1.89074737e-01 7.33787337e-03
# -1.60796958e-03 -2.77107786e-04 -1.02365525e-03 -9.05454729e-04
# -5.98298140e-04 4.25619565e-05 1.53826412e-02 -6.66867308e-02
# 1.48548561e-01 -2.25977534e-02 2.37177063e-04 1.27405510e-03
# 1.41036865e-03 5.21614746e-04 7.13773114e-04 1.94572501e-04
# -4.53901657e-02 -5.52024144e-01 1.17015608e-01 -1.83169390e-03
# -3.94704099e-03 -7.74390329e-03 -1.08822097e-02 -4.24156865e-03
# -7.03799527e-03 -1.17067750e-03]
# [-7.54924585e-02 -6.87405638e-02 8.39642267e-02 -3.00992471e-03
# 3.43658580e-03 1.55731486e-02 1.92512587e-02 9.07295722e-03
# 9.14981253e-03 3.00903313e-03 8.41264091e-02 5.87281487e-01
# 7.77894165e-01 -4.22340772e-02 1.60686058e-03 8.68045428e-03
# 1.17908464e-02 3.99536083e-03 5.54864404e-03 1.44055227e-03
# -1.29429315e-01 1.60158677e-02 -7.32396132e-02 5.01077221e-03
# -2.13704577e-03 -8.59643463e-03 -5.96017341e-03 4.06900452e-04
# -7.69020238e-03 -9.48038255e-05]
# [-3.50549264e-01 4.08376429e-03 1.32828034e-01 -3.82916116e-03
# 8.22698130e-03 5.63148308e-02 7.02297025e-02 1.92498100e-02
# 1.49895864e-02 7.63859418e-03 -9.15127909e-02 1.62432584e-02
# -1.90207008e-01 4.47806986e-03 1.02379224e-03 1.89362797e-02
# 2.52133644e-02 2.90414428e-03 1.76381905e-03 2.10548292e-03
# -8.60507446e-01 -2.11047341e-03 3.86007208e-02 4.09978696e-03
# 1.14296358e-02 1.56279540e-01 1.91177497e-01 3.09293509e-02
# 2.02890503e-02 1.76591684e-02]
# [-1.39559852e-01 -7.66679112e-02 8.92113884e-02 -1.95571374e-03
# 4.44685274e-03 2.99475403e-02 2.79441157e-02 1.04362500e-02
# 8.27800256e-03 4.54280518e-03 3.22073950e-02 -7.91295671e-01
# 5.48501632e-01 -2.03427164e-02 -4.45191352e-04 6.64830146e-03
# 9.88467905e-03 2.63444821e-03 -1.06193376e-03 6.29812931e-04
# -2.83270492e-02 9.44156308e-02 -8.41866696e-02 3.10937022e-03
# 1.17730195e-02 8.54753893e-02 1.00411922e-01 2.84427467e-02
# 2.97280886e-02 1.21927538e-02]]
assert np.allclose(sk_pca_res, cu_pca_res, rtol=1e-8), "Transform results differ!"
assert np.allclose(sk_pca_comp, cu_pca_comp, rtol=1e-8), "Components differ!"
Loading