Skip to content

Commit 41ddfb4

Browse files
committed
add library for 2D polyfitting
this uses the [ceres solver] to implement polyfitting using least squares for a given set of 2D points to find a two-dimensional polynomial ($y = a0 + a1 * x + a2 * x^2$). this can then be used in a next step to fit the line through the banana contour (will be done in a follow-up commit). note that the `polyfit` header-only library is not specific to the current project and could easily be factored out (it has been built in a way to not rely on the rest of the project). it could potentially also be made more abstract to support generic polyfitting for n-dimensional polynomials and/or n-dimensional data points. the API accepts a `range` to accept any form of iterable container. this requires at least C++20, most likely C++23 (where ranges have been updated significantly). it has only been tested against C++23. see the [`std::ranges` documentation] for further details. note that this is using ceres solver v2.1.0 while v2.2.0 has already been released. the reason for this is that v2.1.0 is the latest available version in vcpkg, see microsoft/vcpkg#34483. to better show what is happening a jupyter notebook is provided, visualising the solution approach using python. due to this `.gitignore` had to be updated as jupyterlab stores additional cache files. [ceres solver]: http://ceres-solver.org/ [`std::ranges` documentation]: https://en.cppreference.com/w/cpp/ranges
1 parent d5ba857 commit 41ddfb4

7 files changed

+402
-2
lines changed

.gitignore

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
.idea
2-
cmake-build*
1+
.idea/
2+
cmake-build*/
3+
.ipynb_checkpoints/

docs/banana-center-line-fitting.ipynb

+255
Large diffs are not rendered by default.

include/polyfit/Polynomial2DFit.hpp

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/**\file
2+
* \brief Header-only library which provides a simple API to fit a 2-dimensional polynomial (i.e. 3 coefficients) to a set of two-dimensional points (x, y coordinates).
3+
*/
4+
5+
#ifndef BANANA_PROJECT_POLYNOMIAL2DFIT_HPP
6+
#define BANANA_PROJECT_POLYNOMIAL2DFIT_HPP
7+
8+
#include <expected>
9+
#include <tuple>
10+
#include <utility>
11+
#include <vector>
12+
13+
// TODO: we currently leak the ceres dependency to everyone which includes this header-only library.
14+
// when changing this we'll also need to change the return type of `Fit2DPolynomial` which currently returns
15+
// the ceres termination type in case of errors.
16+
#include <ceres/ceres.h>
17+
18+
/**
19+
* \brief Header-only library which provides a simple API to fit a 2-dimensional polynomial (i.e. 3 coefficients) to a set of two-dimensional points (x, y coordinates).
20+
*/
21+
namespace polyfit {
22+
23+
/** Internal helpers, do not use from the outside! */
24+
namespace internal {
25+
/**
26+
* Calculates the residual (= error = distance) of the estimate for a specified point.
27+
*/
28+
struct Polynomial2DResidual {
29+
public:
30+
Polynomial2DResidual(double const x, double const y) : x_(x), y_(y) {}
31+
32+
template<typename T>
33+
bool operator()(const T *const a0, const T *const a1, const T *const a2, T *residual) const {
34+
auto const y_estimate = a0[0] + a1[0] * x_ + a2[0] * x_ * x_;
35+
residual[0] = y_ - y_estimate;
36+
return true;
37+
}
38+
39+
private:
40+
double const x_;
41+
double const y_;
42+
};
43+
}
44+
45+
/**
46+
* Calculate the coefficients for a two-dimensional polynomial which fits the points as well as possible.
47+
*
48+
* @tparam R the underlying container supporting ranges. must contain std::tuple<double, double>
49+
* @tparam print_report if enabled the underlying solver will print a detailed report to STDOUT. helpful for debugging.
50+
* @param points the set of points (x & y coordinates) for which the polynomial should be fitted.
51+
* @return either the set of coefficients or the termination type indicating why the fitting failed.
52+
*/
53+
template<std::ranges::range R, bool print_report = false>
54+
[[nodiscard]]
55+
auto Fit2DPolynomial(
56+
R&& points) -> std::expected<std::tuple<double, double, double>, ceres::TerminationType> {
57+
double a0 = 1, a1 = 1, a2 = 1;
58+
59+
ceres::Problem problem;
60+
for (auto const& point: points) {
61+
ceres::CostFunction *cost_function =
62+
new ceres::AutoDiffCostFunction<internal::Polynomial2DResidual, 1, 1, 1, 1>(
63+
new internal::Polynomial2DResidual(point.first, point.second));
64+
problem.AddResidualBlock(cost_function, nullptr, &a0, &a1, &a2);
65+
}
66+
67+
ceres::Solver::Options options{
68+
.logging_type = print_report ? ceres::PER_MINIMIZER_ITERATION : ceres::SILENT,
69+
.minimizer_progress_to_stdout = print_report,
70+
};
71+
ceres::Solver::Summary summary;
72+
ceres::Solve(options, &problem, &summary);
73+
if (print_report) {
74+
std::cout << summary.FullReport() << std::endl;
75+
}
76+
if (summary.termination_type == ceres::CONVERGENCE) {
77+
return {{a0, a1, a2}};
78+
} else {
79+
return std::unexpected{summary.termination_type};
80+
}
81+
};
82+
83+
}
84+
85+
#endif //BANANA_PROJECT_POLYNOMIAL2DFIT_HPP

src/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
set(BANANA_HEADER_LIST "${PROJECT_SOURCE_DIR}/include/banana-lib/lib.hpp")
22

33
find_package(OpenCV CONFIG REQUIRED)
4+
find_package(Ceres CONFIG REQUIRED)
45

56
add_library(banana-lib lib.cpp ${BANANA_HEADER_LIST})
67

@@ -12,4 +13,5 @@ target_include_directories(
1213

1314
target_link_libraries(banana-lib
1415
PUBLIC ${OpenCV_LIBS}
16+
PRIVATE Ceres::ceres
1517
)

test/CMakeLists.txt

+7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
find_package(GTest CONFIG REQUIRED)
22
include(GoogleTest)
33

4+
find_package(Ceres CONFIG REQUIRED)
5+
6+
add_executable(polyfit-test polyfit-test.cpp)
7+
target_include_directories(polyfit-test PRIVATE "${PROJECT_SOURCE_DIR}/include")
8+
target_link_libraries(polyfit-test Ceres::ceres GTest::gtest_main)
9+
gtest_discover_tests(polyfit-test)
10+
411
add_executable(banana-lib-test banana-lib-test.cpp)
512
target_link_libraries(banana-lib-test banana-lib GTest::gtest_main)
613
gtest_discover_tests(banana-lib-test)

test/polyfit-test.cpp

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include <gtest/gtest.h>
2+
3+
#include <polyfit/Polynomial2DFit.hpp>
4+
5+
#define ASSERT_COEFFS_NEAR(coeff_expected_0, coeff_expected_1, coeff_expected_2, coeffs) \
6+
do { \
7+
ASSERT_NEAR(coeff_expected_0, std::get<0>(coeffs), 1e-6); \
8+
ASSERT_NEAR(coeff_expected_1, std::get<1>(coeffs), 1e-6); \
9+
ASSERT_NEAR(coeff_expected_2, std::get<2>(coeffs), 1e-6); \
10+
} while (false)
11+
12+
/** y = 1 + x */
13+
TEST(Polynomial2DFitTestSuite, FitSimpleLine) {
14+
std::vector<std::pair<double, double>> points = {
15+
{0,1},
16+
{1,2},
17+
{2,3},
18+
{3,4},
19+
};
20+
auto const result = polyfit::Fit2DPolynomial(points);
21+
ASSERT_TRUE(result);
22+
ASSERT_COEFFS_NEAR(1, 1, 0, *result);
23+
}
24+
25+
/** y = -1 + x^2 */
26+
TEST(Polynomial2DFitTestSuite, FitSimpleCurve) {
27+
std::vector<std::pair<double, double>> points = {
28+
{-1,0},
29+
{0,-1},
30+
{1,0},
31+
};
32+
auto const result = polyfit::Fit2DPolynomial(points);
33+
ASSERT_TRUE(result);
34+
ASSERT_COEFFS_NEAR(-1, 0, 1, *result);
35+
}
36+
37+
/** y = -1 + 3*x + 2*x^2 */
38+
TEST(Polynomial2DFitTestSuite, FitSimpleCurve2) {
39+
std::vector<std::pair<double, double>> points = {
40+
{-1,-2},
41+
{0,-1},
42+
{1,4},
43+
};
44+
auto const result = polyfit::Fit2DPolynomial(points);
45+
ASSERT_TRUE(result);
46+
ASSERT_COEFFS_NEAR(-1, 3, 2, *result);
47+
}

vcpkg.json

+3
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,8 @@
88
}, {
99
"name" : "opencv",
1010
"version>=" : "4.8.0#1"
11+
}, {
12+
"name" : "ceres",
13+
"version>=" : "2.1.0#5"
1114
} ]
1215
}

0 commit comments

Comments
 (0)