-
Notifications
You must be signed in to change notification settings - Fork 62
/
Copy pathmain.cpp
60 lines (46 loc) · 2.03 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
/**
* @file
* @copyright Copyright 2020. Tom de Geus. All rights reserved.
* @license This project is released under the GNU Public License (MIT).
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#define FORCE_IMPORT_ARRAY
#include <xtensor-python/pyarray.hpp>
#include <xtensor-python/pytensor.hpp>
#include <xtensor.hpp>
namespace py = pybind11;
/**
* Overrides the `__name__` of a module.
* Classes defined by pybind11 use the `__name__` of the module as of the time they are defined,
* which affects the `__repr__` of the class type objects.
*/
/**
 * Scope guard that temporarily overrides the `__name__` attribute of a module.
 *
 * Classes defined by pybind11 use the `__name__` of the module as of the time they are defined,
 * which affects the `__repr__` of the class type objects. The constructor replaces `__name__`
 * by `name` and the destructor writes the previously stored value back.
 *
 * The guard is neither copyable nor movable: a second instance sharing the same stored name
 * would restore `__name__` as soon as either one is destroyed, cutting the override short.
 */
class ScopedModuleNameOverride {
public:
    /**
     * @param m Module whose `__name__` is overridden; the guard keeps its own handle to it.
     * @param name Value assigned to `__name__` for the lifetime of the guard.
     */
    explicit ScopedModuleNameOverride(py::module m, std::string name) : module_(std::move(m))
    {
        // Remember the current name so the destructor can put it back.
        original_name_ = module_.attr("__name__");
        module_.attr("__name__") = name;
    }

    // Rule of Five: the destructor has a side effect, so forbid duplicating the guard.
    ScopedModuleNameOverride(const ScopedModuleNameOverride&) = delete;
    ScopedModuleNameOverride& operator=(const ScopedModuleNameOverride&) = delete;
    ScopedModuleNameOverride(ScopedModuleNameOverride&&) = delete;
    ScopedModuleNameOverride& operator=(ScopedModuleNameOverride&&) = delete;

    /// Restores the `__name__` value captured by the constructor.
    ~ScopedModuleNameOverride()
    {
        module_.attr("__name__") = original_name_;
    }

private:
    py::module module_;        ///< Module whose name is overridden.
    py::object original_name_; ///< `__name__` as it was before the override.
};
/**
 * Defines the `_xt` extension module: thin Python wrappers around a few xtensor functions.
 * Each binding takes its inputs as `xt::pyarray` and returns an `xt::pyarray`.
 * All arguments carry names (NumPy-style where a NumPy counterpart exists), so they can be
 * passed either positionally (as before) or by keyword.
 */
PYBIND11_MODULE(_xt, m)
{
    // Ensure members display as `xt.X` (not `xt._xt.X`); the original `__name__` is restored
    // when `name_override` goes out of scope at the end of this function.
    ScopedModuleNameOverride name_override(m, "xt");

    // Initialise NumPy's C-API for xtensor-python (paired with FORCE_IMPORT_ARRAY above).
    xt::import_numpy();

    m.doc() = "Python bindings of xtensor";

    m.def(
        "mean",
        [](const xt::pyarray<double>& a) -> xt::pyarray<double> { return xt::mean(a); },
        "Wraps ``xt::mean(a)``.",
        py::arg("a"));

    m.def(
        "average",
        [](const xt::pyarray<double>& a, const xt::pyarray<double>& w) -> xt::pyarray<double> {
            return xt::average(a, w);
        },
        "Wraps ``xt::average(a, weights)``.",
        py::arg("a"),
        py::arg("weights"));

    m.def(
        "average",
        [](const xt::pyarray<double>& a,
           const xt::pyarray<double>& w,
           const std::vector<ptrdiff_t>& axes) -> xt::pyarray<double> {
            return xt::average(a, w, axes);
        },
        "Wraps ``xt::average(a, weights, axis)``, reducing over the listed axes.",
        py::arg("a"),
        py::arg("weights"),
        py::arg("axis"));

    m.def(
        "flip",
        [](const xt::pyarray<double>& a, ptrdiff_t axis) -> xt::pyarray<double> {
            return xt::flip(a, axis);
        },
        "Wraps ``xt::flip(a, axis)``.",
        py::arg("a"),
        py::arg("axis"));

    m.def(
        "cos",
        [](const xt::pyarray<double>& a) -> xt::pyarray<double> { return xt::cos(a); },
        "Wraps ``xt::cos(a)`` (element-wise cosine).",
        py::arg("a"));

    m.def(
        "isin",
        [](const xt::pyarray<int>& a, const xt::pyarray<int>& b) -> xt::pyarray<bool> {
            return xt::isin(a, b);
        },
        "Wraps ``xt::isin(element, test_elements)``.",
        py::arg("element"),
        py::arg("test_elements"));

    m.def(
        "in1d",
        [](const xt::pyarray<int>& a, const xt::pyarray<int>& b) -> xt::pyarray<bool> {
            return xt::in1d(a, b);
        },
        "Wraps ``xt::in1d(ar1, ar2)``.",
        py::arg("ar1"),
        py::arg("ar2"));
}