|
| 1 | +/**\file |
| 2 | + * \brief Header-only library which provides a simple API to fit a 2-dimensional polynomial (i.e. 3 coefficients) to a set of two-dimensional points (x, y coordinates). |
| 3 | + */ |
| 4 | + |
| 5 | +#ifndef BANANA_PROJECT_POLYNOMIAL2DFIT_HPP |
| 6 | +#define BANANA_PROJECT_POLYNOMIAL2DFIT_HPP |
| 7 | + |
| 8 | +#include <expected> |
| 9 | +#include <tuple> |
| 10 | +#include <utility> |
| 11 | +#include <vector> |
| 12 | + |
| 13 | +// TODO: we currently leak the ceres dependency to everyone which includes this header-only library. |
| 14 | +// when changing this we'll also need to change the return type of `Fit2DPolynomial` which currently returns |
| 15 | +// the ceres termination type in case of errors. |
| 16 | +#include <ceres/ceres.h> |
| 17 | + |
| 18 | +/** |
| 19 | + * \brief Header-only library which provides a simple API to fit a 2-dimensional polynomial (i.e. 3 coefficients) to a set of two-dimensional points (x, y coordinates). |
| 20 | + */ |
| 21 | +namespace polyfit { |
| 22 | + |
| 23 | + /** Internal helpers, do not use from the outside! */ |
| 24 | + namespace internal { |
| 25 | + /** |
| 26 | + * Calculates the residual (= error = distance) of the estimate for a specified point. |
| 27 | + */ |
| 28 | + struct Polynomial2DResidual { |
| 29 | + public: |
| 30 | + Polynomial2DResidual(double const x, double const y) : x_(x), y_(y) {} |
| 31 | + |
| 32 | + template<typename T> |
| 33 | + bool operator()(const T *const a0, const T *const a1, const T *const a2, T *residual) const { |
| 34 | + auto const y_estimate = a0[0] + a1[0] * x_ + a2[0] * x_ * x_; |
| 35 | + residual[0] = y_ - y_estimate; |
| 36 | + return true; |
| 37 | + } |
| 38 | + |
| 39 | + private: |
| 40 | + double const x_; |
| 41 | + double const y_; |
| 42 | + }; |
| 43 | + } |
| 44 | + |
| 45 | + /** |
| 46 | + * Calculate the coefficients for a two-dimensional polynomial which fits the points as well as possible. |
| 47 | + * |
| 48 | + * @tparam R the underlying container supporting ranges. must contain std::tuple<double, double> |
| 49 | + * @tparam print_report if enabled the underlying solver will print a detailed report to STDOUT. helpful for debugging. |
| 50 | + * @param points the set of points (x & y coordinates) for which the polynomial should be fitted. |
| 51 | + * @return either the set of coefficients or the termination type indicating why the fitting failed. |
| 52 | + */ |
| 53 | + template<std::ranges::range R, bool print_report = false> |
| 54 | + [[nodiscard]] |
| 55 | + auto Fit2DPolynomial( |
| 56 | + R&& points) -> std::expected<std::tuple<double, double, double>, ceres::TerminationType> { |
| 57 | + double a0 = 1, a1 = 1, a2 = 1; |
| 58 | + |
| 59 | + ceres::Problem problem; |
| 60 | + for (auto const& point: points) { |
| 61 | + ceres::CostFunction *cost_function = |
| 62 | + new ceres::AutoDiffCostFunction<internal::Polynomial2DResidual, 1, 1, 1, 1>( |
| 63 | + new internal::Polynomial2DResidual(point.first, point.second)); |
| 64 | + problem.AddResidualBlock(cost_function, nullptr, &a0, &a1, &a2); |
| 65 | + } |
| 66 | + |
| 67 | + ceres::Solver::Options options{ |
| 68 | + .logging_type = print_report ? ceres::PER_MINIMIZER_ITERATION : ceres::SILENT, |
| 69 | + .minimizer_progress_to_stdout = print_report, |
| 70 | + }; |
| 71 | + ceres::Solver::Summary summary; |
| 72 | + ceres::Solve(options, &problem, &summary); |
| 73 | + if (print_report) { |
| 74 | + std::cout << summary.FullReport() << std::endl; |
| 75 | + } |
| 76 | + if (summary.termination_type == ceres::CONVERGENCE) { |
| 77 | + return {{a0, a1, a2}}; |
| 78 | + } else { |
| 79 | + return std::unexpected{summary.termination_type}; |
| 80 | + } |
| 81 | + }; |
| 82 | + |
| 83 | +} |
| 84 | + |
| 85 | +#endif //BANANA_PROJECT_POLYNOMIAL2DFIT_HPP |
0 commit comments