Skip to content

Commit ca4cc18

Browse files
committed
xt::linalg::kron: support arguments of arbitrary number of dimensions
Before this commit, `xt::linalg::kron` only supports 2D arguments. This commit proposes to add support for argument with any number of dimensions. This change of behavior is coherent with what `numpy.kron` does. Given the implementation, it would probably make sense to return a xexpression instead of a xarray and allow lazy evaluation. This might be done in a separate commit. It could also be possible to have a dynamic check of the number of dimensions and use specialized implementation for the more common cases (i.e. 2D) at runtime, which should be more efficient. Tested with `./test/test_xtensor_blas --gtest_filter=xlinalg.kron*`.
1 parent 7ceb791 commit ca4cc18

File tree

2 files changed

+50
-14
lines changed

2 files changed

+50
-14
lines changed

include/xtensor-blas/xlinalg.hpp

+34-14
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#define XLINALG_HPP
1212

1313
#include <algorithm>
14+
#include <functional>
1415
#include <limits>
1516
#include <sstream>
1617
#include <chrono>
@@ -1460,7 +1461,7 @@ namespace linalg
14601461
}
14611462

14621463
/**
1463-
* Calculate the Kronecker product between two 2D xexpressions.
1464+
* Calculate the Kronecker product between two kD xexpressions.
14641465
*/
14651466
template <class T, class E>
14661467
auto kron(const xexpression<T>& a, const xexpression<E>& b)
@@ -1470,23 +1471,42 @@ namespace linalg
14701471
const auto& da = a.derived_cast();
14711472
const auto& db = b.derived_cast();
14721473

1473-
XTENSOR_ASSERT(da.dimension() == 2);
1474-
XTENSOR_ASSERT(db.dimension() == 2);
1474+
XTENSOR_ASSERT(da.dimension() == bd.dimension());
14751475

1476-
std::array<std::size_t, 2> shp = {da.shape()[0] * db.shape()[0], da.shape()[1] * db.shape()[1]};
1477-
xtensor<value_type, 2> res(shp);
1476+
const auto shapea = da.shape();
1477+
const auto shapeb = db.shape();
1478+
const std::vector<std::size_t> shp = [&shapea, &shapeb](){
1479+
std::vector<std::size_t> r;
1480+
r.reserve(shapea.size());
1481+
std::transform(shapea.begin(), shapea.end(), shapeb.begin(),
1482+
std::back_inserter(r), std::multiplies<std::size_t>());
1483+
return r;
1484+
}();
14781485

1479-
for (std::size_t i = 0; i < da.shape()[0]; ++i)
1486+
xarray<value_type> res(shp);
1487+
1488+
std::vector<std::size_t> ires(da.dimension(), 0);
1489+
std::vector<std::size_t> ia(da.dimension(), 0);
1490+
std::vector<std::size_t> ib(da.dimension(), 0);
1491+
1492+
for (size_t i = 0; i < res.size(); i++)
14801493
{
1481-
for (std::size_t j = 0; j < da.shape()[1]; ++j)
1494+
for (size_t j = 0; j < shp.size(); j++)
14821495
{
1483-
for (std::size_t k = 0; k < db.shape()[0]; ++k)
1484-
{
1485-
for (std::size_t h = 0; h < db.shape()[1]; ++h)
1486-
{
1487-
res(i * db.shape()[0] + k, j * db.shape()[1] + h) = da(i, j) * db(k, h);
1488-
}
1489-
}
1496+
ia[j] = ires[j] / shapeb[j];
1497+
ib[j] = ires[j] % shapeb[j];
1498+
}
1499+
1500+
res[ires] = da[ia] * db[ib];
1501+
1502+
// Figure out the index of the next element
1503+
size_t j = ires.size() - 1;
1504+
ires[j]++;
1505+
while (ires[j] >= shp[j] && j > 0)
1506+
{
1507+
ires[j] = 0;
1508+
ires[j - 1]++;
1509+
j--;
14901510
}
14911511
}
14921512

test/test_linalg.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,22 @@ namespace xt
366366
EXPECT_EQ(expected, res);
367367
}
368368

369+
TEST(xlinalg, kron_3d)
370+
{
371+
xarray<int> arg_0 = {{{1, 2, 3}}};
372+
373+
xarray<int> arg_1 = xt::ones<int>({2, 2, 1});
374+
375+
auto res = xt::linalg::kron(arg_0, arg_1);
376+
377+
xarray<int> expected = {{{1, 2, 3},
378+
{1, 2, 3}},
379+
{{1, 2, 3},
380+
{1, 2, 3}}};
381+
382+
EXPECT_EQ(expected, res);
383+
}
384+
369385
TEST(xlinalg, cholesky)
370386
{
371387
xarray<double> arg_0 = {{ 4, 12,-16},

0 commit comments

Comments
 (0)