diff --git a/R-package/R/metrics.R b/R-package/R/metrics.R index e7099e9483e0..12a8047b9915 100644 --- a/R-package/R/metrics.R +++ b/R-package/R/metrics.R @@ -24,6 +24,7 @@ , "map" = TRUE , "auc" = TRUE , "average_precision" = TRUE + , "r2" = TRUE , "binary_logloss" = FALSE , "binary_error" = FALSE , "auc_mu" = TRUE diff --git a/docs/Parameters.rst b/docs/Parameters.rst index d67badb63f59..3a4d880ef2e0 100644 --- a/docs/Parameters.rst +++ b/docs/Parameters.rst @@ -1261,6 +1261,8 @@ Metric Parameters - ``average_precision``, `average precision score `__ + - ``r2``, `R-squared `__ + - ``binary_logloss``, `log loss `__, aliases: ``binary`` - ``binary_error``, for one sample: ``0`` for correct classification, ``1`` for error classification diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index 00dc9ba548c8..60966f54e3ea 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -1028,6 +1028,7 @@ struct Config { // descl2 = ``map``, `MAP `__, aliases: ``mean_average_precision`` // descl2 = ``auc``, `AUC `__ // descl2 = ``average_precision``, `average precision score `__ + // descl2 = ``r2``, `R-squared `__ // descl2 = ``binary_logloss``, `log loss `__, aliases: ``binary`` // descl2 = ``binary_error``, for one sample: ``0`` for correct classification, ``1`` for error classification // descl2 = ``auc_mu``, `AUC-mu `__ diff --git a/src/metric/metric.cpp b/src/metric/metric.cpp index e773f61ae01a..ad38a505045d 100644 --- a/src/metric/metric.cpp +++ b/src/metric/metric.cpp @@ -78,6 +78,9 @@ Metric* Metric::CreateMetric(const std::string& type, const Config& config) { return new CUDAGammaDevianceMetric(config); } else if (type == std::string("tweedie")) { return new CUDATweedieMetric(config); + } else if (type == std::string("r2")) { + Log::Warning("Metric r2 is not implemented in cuda version. Fall back to evaluation on CPU."); + return new R2Metric(config); } } else { #endif // USE_CUDA @@ -127,6 +130,8 @@ Metric* Metric::CreateMetric(const std::string& type, const Config& config) { return new GammaDevianceMetric(config); } else if (type == std::string("tweedie")) { return new TweedieMetric(config); + } else if (type == std::string("r2")) { + return new R2Metric(config); } #ifdef USE_CUDA } diff --git a/src/metric/regression_metric.hpp b/src/metric/regression_metric.hpp index fbe6f2a062fb..b1f276226454 100644 --- a/src/metric/regression_metric.hpp +++ b/src/metric/regression_metric.hpp @@ -318,5 +318,115 @@ class TweedieMetric : public RegressionMetric { }; +class R2Metric: public Metric { + public: + explicit R2Metric(const Config& config) :config_(config) {} + const std::vector& GetName() const override { + return name_; + } + + double factor_to_bigger_better() const override { + return 1.0f; + } + + void Init(const Metadata& metadata, data_size_t num_data) override { + name_.emplace_back("r2"); + num_data_ = num_data; + label_ = metadata.label(); + weights_ = metadata.weights(); + + double sum_label = 0.0f; + if (weights_ == nullptr) { + sum_weights_ = static_cast(num_data_); + #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:sum_label) + for (data_size_t i = 0; i < num_data_; ++i) { + sum_label += label_[i]; + } + } else { + double local_sum_weights = 0.0f; + #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:local_sum_weights, sum_label) + for (data_size_t i = 0; i < num_data_; ++i) { + local_sum_weights += weights_[i]; + sum_label += label_[i] * weights_[i]; + } + sum_weights_ = local_sum_weights; + } + label_mean_ = sum_label / sum_weights_; + + total_sum_squares_ = 0.0f; + double local_total_sum_squares = 0.0f; + if (weights_ == nullptr) { + #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:local_total_sum_squares) + for (data_size_t i = 0; i < num_data_; ++i) { + double diff = label_[i] - label_mean_; + local_total_sum_squares += diff * diff; + } + } else { + #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:local_total_sum_squares) + for (data_size_t i = 0; i < num_data_; ++i) { + double diff = label_[i] - label_mean_; + local_total_sum_squares += diff * diff * weights_[i]; + } + } + total_sum_squares_ = local_total_sum_squares; + } + + std::vector Eval(const double* score, const ObjectiveFunction* objective) const override { + double residual_sum_squares = 0.0f; + if (objective == nullptr) { + if (weights_ == nullptr) { + #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:residual_sum_squares) + for (data_size_t i = 0; i < num_data_; ++i) { + double diff = label_[i] - score[i]; + residual_sum_squares += diff * diff; + } + } else { + #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:residual_sum_squares) + for (data_size_t i = 0; i < num_data_; ++i) { + double diff = label_[i] - score[i]; + residual_sum_squares += diff * diff * weights_[i]; + } + } + } else { + if (weights_ == nullptr) { + #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:residual_sum_squares) + for (data_size_t i = 0; i < num_data_; ++i) { + double t = 0; + objective->ConvertOutput(&score[i], &t); + double diff = label_[i] - t; + residual_sum_squares += diff * diff; + } + } else { + #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:residual_sum_squares) + for (data_size_t i = 0; i < num_data_; ++i) { + double t = 0; + objective->ConvertOutput(&score[i], &t); + double diff = label_[i] - t; + residual_sum_squares += diff * diff * weights_[i]; + } + } + } + + double r2 = 1.0 - (residual_sum_squares / total_sum_squares_); + if (std::fabs(total_sum_squares_) < kZeroThreshold) { + return std::vector(1, std::fabs(residual_sum_squares) < kZeroThreshold ? 1.0 : 0.0); + } + return std::vector(1, r2); + } + + protected: + data_size_t num_data_; + const label_t* label_; + const label_t* weights_; + double sum_weights_; + Config config_; + std::vector name_; + + // Custom members for R2 calculation + double label_mean_; + double total_sum_squares_; +}; + + } // namespace LightGBM #endif // LightGBM_METRIC_REGRESSION_METRIC_HPP_ diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 1008e71ae14b..9c9c96a977fb 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -16,7 +16,14 @@ import pytest from scipy.sparse import csr_matrix, isspmatrix_csc, isspmatrix_csr from sklearn.datasets import load_svmlight_file, make_blobs, make_classification, make_multilabel_classification -from sklearn.metrics import average_precision_score, log_loss, mean_absolute_error, mean_squared_error, roc_auc_score +from sklearn.metrics import ( + average_precision_score, + log_loss, + mean_absolute_error, + mean_squared_error, + r2_score, + roc_auc_score, +) from sklearn.model_selection import GroupKFold, TimeSeriesSplit, train_test_split import lightgbm as lgb @@ -4049,6 +4056,29 @@ def test_average_precision_metric(): assert res["training"]["average_precision"][-1] == pytest.approx(1) +def test_r2_metric(): + # test against sklearn R2 metric + X, y = make_synthetic_regression() + params = {"objective": "regression", "metric": "r2", "verbose": -1} + res = {} + train_data = lgb.Dataset(X, label=y) + est = lgb.train( + params, train_data, num_boost_round=1, valid_sets=[train_data], callbacks=[lgb.record_evaluation(res)] + ) + r2 = res["training"]["r2"][-1] + pred = est.predict(X) + sklearn_r2 = r2_score(y, pred) + assert r2 == pytest.approx(sklearn_r2) + assert r2 != 0 + assert r2 != 1 + # test that R2 is 1 when y has no variance and the model predicts perfectly + y = y.copy() + y[:] = 1 + lgb_X = lgb.Dataset(X, label=y) + lgb.train(params, lgb_X, num_boost_round=1, valid_sets=[lgb_X], callbacks=[lgb.record_evaluation(res)]) + assert res["training"]["r2"][-1] == pytest.approx(1) + + def test_reset_params_works_with_metric_num_class_and_boosting(): X, y = load_breast_cancer(return_X_y=True) dataset_params = {"max_bin": 150}