Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
340 changes: 335 additions & 5 deletions btas/generic/converge_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
#include <cmath>
#include <complex>
#include <iomanip>
#include <iostream>
#include <vector>

#include <btas/generic/contract.h>
#include <btas/generic/dot_impl.h>
#include <btas/generic/reconstruct.h>
#include <btas/generic/scal_impl.h>
#include <btas/varray/varray.h>

namespace btas {
Expand All @@ -23,7 +26,7 @@ namespace btas {
public:
/// constructor for the base convergence test object
/// \param[in] tol tolerance for ALS convergence
explicit NormCheck(double tol = 1e-3) : tol_(tol) {
explicit NormCheck(double tol = 1e-3) : tol_(tol), iter_(0) {
}

~NormCheck() = default;
Expand All @@ -50,17 +53,34 @@ namespace btas {
prev[r] = btas_factors[r];
}

if (verbose_) {
std::cout << rank_ << "\t" << iter_ << "\t" << std::setprecision(16) << diff << std::endl;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to use proper logging ... in another PR

}
if (diff < this->tol_) {
return true;
}
++iter_;

return false;
}

/// Option to print fit and change in fit in the () operator call
/// \param[in] verb bool which turns off/on fit printing.
void verbose(bool verb) {
verbose_ = verb;
}

double get_fit(bool hit_max_iters = false){

}

private:
double tol_;
std::vector<Tensor> prev; // Set of previous factor matrices
size_t ndim; // Number of factor matrices
ind_t rank_; // Rank of the CP problem
bool verbose_ = false;
size_t iter_;
};

/**
Expand Down Expand Up @@ -105,7 +125,7 @@ namespace btas {
}

double normFactors = norm(btas_factors, V);
double normResidual = sqrt(abs(normT_ * normT_ + normFactors * normFactors - 2 * abs(iprod)));
double normResidual = sqrt(abs(normT_ * normT_ + normFactors - 2 * abs(iprod)));
double fit = 1. - (normResidual / normT_);

double fitChange = abs(fitOld_ - fit);
Expand Down Expand Up @@ -208,11 +228,11 @@ namespace btas {
}
}

dtype nrm = 0.0;
RT nrm = 0.0;
for (auto &i : coeffMat) {
nrm += i;
nrm += std::real(i);
}
return sqrt(abs(nrm));
return nrm;
}
};

Expand Down Expand Up @@ -421,5 +441,315 @@ namespace btas {
return sqrt(abs(nrm));
}
};

/**
  \brief Convergence "check" that never reports convergence.
  The fit is not computed and the ALS optimization simply runs until
  nALS (the maximum number of iterations) is reached.
**/
template <typename Tensor>
class NoCheck {
  using ind_t = typename Tensor::range_type::index_type::value_type;
  using ord_t = typename range_traits<typename Tensor::range_type>::ordinal_type;

 public:
  /// constructor for the no-op convergence test object
  /// \param[in] tol unused; accepted only so NoCheck is interchangeable
  /// with the other convergence-check classes.
  explicit NoCheck(double tol = 1e-3) : iter_(0) {
  }

  ~NoCheck() = default;

  /// Always reports "not converged" so the ALS loop runs to its
  /// iteration limit.
  /// \param[in] btas_factors Current set of factor matrices (read only to
  /// report the CP rank when verbose printing is enabled)
  /// \param[in] V unused
  /// \return false always
  bool operator()(const std::vector<Tensor> &btas_factors,
                  const std::vector<Tensor> &V = std::vector<Tensor>()) {
    auto rank_ = btas_factors[1].extent(1);
    if (verbose_) {
      std::cout << rank_ << "\t" << iter_ << std::endl;
    }
    ++iter_;

    return false;
  }

  /// Option to print the rank and iteration count in the () operator call
  /// \param[in] verb bool which turns off/on printing.
  void verbose(bool verb) {
    verbose_ = verb;
  }

  /// The fit is never computed by this class.
  /// \return 0.0 always.  fixed: previously this function had no return
  /// statement, which is undefined behavior for a non-void return type.
  double get_fit(bool hit_max_iters = false) {
    return 0.0;
  }

 private:
  bool verbose_ = false;  // print per-iteration info when true
  size_t iter_;           // ALS iteration counter
};

/// Convergence check that compares the current CP approximation to the
/// previous iteration's approximation.  The total fit to the target tensor
/// is skipped; the relative change
/// \f$ \| T^{i+1} - T^{i} \| / \| T^{i} \| \f$ is computed directly from
/// factor-matrix contractions.
template <typename Tensor>
class ApproxFitCheck {
  using ind_t = typename Tensor::range_type::index_type::value_type;
  using ord_t = typename range_traits<typename Tensor::range_type>::ordinal_type;

 public:
  /// constructor for the approximate-fit convergence test object
  /// \param[in] tol tolerance for ALS convergence
  explicit ApproxFitCheck(double tol = 1e-3) : iter_(0), tol_(tol) {
  }

  ~ApproxFitCheck() = default;

  /// Check convergence of the ALS problem.
  /// Converged when the relative change
  /// \f$ \| T^{i+1} - T^{i} \| / \| T^{i} \| \f$ has been below the
  /// tolerance twice.
  /// \param[in] btas_factors Current set of factor matrices; the last
  /// entry holds the scaling vector \f$ \lambda \f$.
  /// \param[in] V unused here (accepted for interface compatibility)
  bool operator()(std::vector<Tensor> &btas_factors,
                  const std::vector<Tensor> &V = std::vector<Tensor>()) {
    auto rank_ = btas_factors[1].extent(1);

    auto fit = 0.0;
    if (iter_ == 0) {
      // First call: cache |T^0|^2 and the factors; nothing to compare yet.
      fit_prev_ = norm(btas_factors, btas_factors, rank_);
      norm_prev_ = sqrt(fit_prev_);
      prev_factors = btas_factors;
      if (verbose_) {
        std::cout << rank_ << "\t" << iter_ << "\t" << 1.0 << std::endl;
      }
      ++iter_;
      return false;
    }

    // |T^{i+1} - T^i|^2 = |T^i|^2 - 2 <T^i, T^{i+1}> + |T^{i+1}|^2
    auto curr_norm = norm(btas_factors, btas_factors, rank_);
    fit = sqrt(fit_prev_ - 2 * norm(prev_factors, btas_factors, rank_) + curr_norm) / norm_prev_;
    fit_prev_ = curr_norm;
    norm_prev_ = sqrt(curr_norm);
    prev_factors = btas_factors;

    if (verbose_) {
      std::cout << rank_ << "\t" << iter_ << "\t" << fit << std::endl;
    }
    ++iter_;
    if (fit < tol_) {
      ++converged_num;
      if (converged_num > 1) {
        // Reset state so this object can be reused for the next CP rank.
        // fixed: converged_num was previously left nonzero, so a reused
        // object would report convergence after a single iteration.
        iter_ = 0;
        converged_num = 0;
        return true;
      }
    }
    return false;
  }

  /// Option to print fit and change in fit in the () operator call
  /// \param[in] verb bool which turns off/on fit printing.
  void verbose(bool verb) {
    verbose_ = verb;
  }

 private:
  double tol_;                       // ALS convergence tolerance
  bool verbose_ = false;             // print per-iteration info when true
  double fit_prev_, norm_prev_;      // |T^i|^2 and |T^i| from the previous call
  std::vector<Tensor> prev_factors;  // factor matrices from the previous call
  size_t converged_num = 0;          // # of sub-tolerance iterations seen
  size_t iter_;                      // ALS iteration counter

  /// Inner product <T1, T2> of two CP approximations given by their factor
  /// matrices (the last entry of each vector is the lambda scaling vector):
  /// lambda-weighted sum of the Hadamard product of the per-mode gram
  /// matrices \f$ A_n^T B_n \f$.
  double norm(std::vector<Tensor> &facs1, std::vector<Tensor> &facs2, ind_t rank_) {
    BTAS_ASSERT(facs1.size() == facs2.size());
    ind_t num_factors = facs1.size();
    Tensor RRp;
    // Local copies: the lambda scaling below must not modify the callers'
    // factor matrices (the original also un-scaled these copies afterwards,
    // which was dead work and has been removed).
    Tensor T1 = facs1[0], T2 = facs2[0];
    auto lam_ptr1 = facs1[num_factors - 1].data(),
         lam_ptr2 = facs2[num_factors - 1].data();
    for (ind_t i = 0; i < rank_; i++) {
      scal(T1.extent(0), *(lam_ptr1 + i), std::begin(T1) + i, rank_);
      scal(T2.extent(0), *(lam_ptr2 + i), std::begin(T2) + i, rank_);
    }

    contract(1.0, T1, {1, 2}, T2, {1, 3}, 0.0, RRp, {2, 3});

    // Hadamard product over the middle-mode gram matrices.
    // fixed: the loop previously stopped at num_factors - 3, skipping the
    // mode at index num_factors - 3 entirely (e.g. for 3 modes + lambda,
    // factor 1 was never folded into the inner product).
    auto *ptr_RRp = RRp.data();
    for (ind_t i = 1; i < num_factors - 2; ++i) {
      Tensor temp;
      contract(1.0, facs1[i], {1, 2}, facs2[i], {1, 3}, 0.0, temp, {2, 3});
      auto *ptr_temp = temp.data();
      for (ord_t r = 0; r < rank_ * rank_; ++r)
        *(ptr_RRp + r) *= *(ptr_temp + r);
    }
    // The last mode is folded in through the dot product.
    Tensor temp;
    auto last = num_factors - 2;
    contract(1.0, facs1[last], {1, 2}, facs2[last], {1, 3}, 0.0, temp, {2, 3});
    return btas::dot(RRp, temp);
  }
};

/**
  \brief Convergence check based on the difference of two consecutive fits:
  \f$ \| T - T^{i} \|^2 - \| T - T^{i+1} \|^2
      = \|T^{i}\|^2 - 2 \langle T, T^{i} \rangle
        + 2 \langle T, T^{i+1} \rangle - \|T^{i+1}\|^2 \f$
  The constant \f$ \|T\|^2 \f$ cancels in the difference, so the norm of
  the target tensor is never required.
**/
template <typename Tensor>
class DiffFitCheck {
  using ind_t = typename Tensor::range_type::index_type::value_type;
  using ord_t = typename range_traits<typename Tensor::range_type>::ordinal_type;
  using dtype = typename Tensor::value_type;

 public:
  /// constructor for the difference-of-fits convergence test object
  /// \param[in] tol tolerance for ALS convergence
  explicit DiffFitCheck(double tol = 1e-3) : iter_(0), tol_(tol) {
  }

  ~DiffFitCheck() = default;

  /// Check convergence of the ALS problem.
  /// Converged when the change in the (shifted) residual norm has been
  /// below the tolerance twice.  set_MtKRP() must be called with the
  /// current matricized-tensor-times-Khatri-Rao product before each call.
  /// \param[in] btas_factors Current set of factor matrices; the last
  /// entry holds the scaling vector \f$ \lambda \f$.
  /// \param[in] V gram matrices \f$ A_n^T A_n \f$ of the factors.
  /// NOTE(review): the default empty V yields a meaningless norm — confirm
  /// all callers pass V explicitly.
  bool operator()(std::vector<Tensor> &btas_factors,
                  const std::vector<Tensor> &V = std::vector<Tensor>()) {
    auto rank_ = btas_factors[1].extent(1);
    auto n = btas_factors.size() - 1;
    auto &lambda = btas_factors[n];
    auto fit = 0.0;
    if (iter_ == 0) {
      // First call: cache the shifted residual norm; nothing to compare yet.
      fit_prev_ = sqrt(abs(norm(V, lambda, rank_) - 2.0 * abs(compute_inner_product(btas_factors[n - 1], lambda))));
      if (verbose_) {
        std::cout << rank_ << "\t" << iter_ << "\t" << 1.0 << std::endl;
      }
      ++iter_;
      return false;
    }

    // sqrt(| |T - T^i|^2 - |T - T^{i+1}|^2 |); the |T|^2 terms cancel.
    auto curr_norm = sqrt(abs(norm(V, lambda, rank_) - 2.0 * abs(compute_inner_product(btas_factors[n - 1], lambda))));
    fit = sqrt(abs(fit_prev_ * fit_prev_ - curr_norm * curr_norm));
    fit_prev_ = curr_norm;

    if (verbose_) {
      std::cout << rank_ << "\t" << iter_ << "\t" << fit << std::endl;
    }
    ++iter_;
    if (fit < tol_) {
      ++converged_num;
      if (converged_num > 1) {
        // Reset state so this object can be reused for the next CP rank.
        // fixed: previously neither counter was cleared, so a reused object
        // would report convergence after a single iteration.
        converged_num = 0;
        iter_ = 0;
        return true;
      }
    }
    return false;
  }

  /// Option to print fit and change in fit in the () operator call
  /// \param[in] verb bool which turns off/on fit printing.
  void verbose(bool verb) {
    verbose_ = verb;
  }

  /// Provide the matricized-tensor-times-Khatri-Rao product for the
  /// current iteration; used by compute_inner_product() to form
  /// \f$ \langle T, T^{i} \rangle \f$.
  void set_MtKRP(Tensor &MtKRP) {
    MtKRP_ = MtKRP;
  }

 private:
  double tol_;               // ALS convergence tolerance
  bool verbose_ = false;     // print per-iteration info when true
  double fit_prev_;          // shifted residual norm from the previous call
  Tensor MtKRP_;             // matricized tensor times Khatri-Rao product
  size_t converged_num = 0;  // # of sub-tolerance iterations seen
  size_t iter_;              // ALS iteration counter

  /// \return \f$ \langle T, T^{i} \rangle =
  /// \sum_{jr} (MtKRP)_{jr} \bar{A}_{jr} \lambda_r \f$
  /// NOTE(review): the `i % rank` lambda indexing assumes the factor is
  /// stored contiguous and row-major — confirm for non-default layouts.
  dtype compute_inner_product(Tensor &last_factor, Tensor &lambda) {
    ord_t size = last_factor.size();
    ind_t rank = last_factor.extent(1);
    auto *ptr_A = last_factor.data();
    auto *ptr_MtKRP = MtKRP_.data();
    auto lam_ptr = lambda.data();
    dtype iprod = 0.0;
    for (ord_t i = 0; i < size; ++i) {
      iprod += *(ptr_MtKRP + i) * btas::impl::conj(*(ptr_A + i)) * *(lam_ptr + i % rank);
    }
    return iprod;
  }

  /// Squared norm \f$ \|T^{i}\|^2 \f$ of the CP approximation computed
  /// from the gram matrices: elementwise sum of
  /// \f$ (\lambda^* \lambda^T) \circ \prod_n V_n \f$.
  double norm(const std::vector<Tensor> &V, Tensor &lambda, ind_t rank_) {
    auto n = V.size();
    Tensor coeffMat;
    typename Tensor::value_type one = 1.0;
    ger(one, lambda.conj(), lambda, coeffMat);

    auto rank2 = rank_ * (ord_t)rank_;

    auto *ptr_coeff = coeffMat.data();
    for (size_t i = 0; i < n; ++i) {
      auto *ptr_V = V[i].data();
      for (ord_t j = 0; j < rank2; ++j) {
        *(ptr_coeff + j) *= *(ptr_V + j);
      }
    }

    // Accumulate the real part so the result converts to double even for
    // complex value_type (fixed: `dtype nrm` returned as double is
    // ill-formed for complex tensors; the imaginary parts cancel
    // analytically).  Mirrors the same fix already applied to FitCheck.
    double nrm = 0.0;
    for (auto &i : coeffMat) {
      nrm += std::real(i);
    }
    return nrm;
  }
};
} //namespace btas
#endif // BTAS_GENERIC_CONV_BASE_CLASS
Loading
Loading