Commit ec3d700

OriolAbril authored and ricardoV94 committed

Implement unstack operation for XTensorVariables

1 parent 6821b93 commit ec3d700
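In xarray terms, unstack is the inverse of stack: it splits one named dimension into several new named dimensions of known lengths. A minimal sketch of the semantics this commit implements, written directly against NumPy and xarray (the names and sizes here are illustrative only):

import numpy as np
from xarray import DataArray

# A 2x15 array whose "bc" dimension packs b (length 3) and c (length 5)
x = DataArray(np.arange(30).reshape(2, 15), dims=("a", "bc"))

# Unstacking "bc" into b and c with known lengths amounts to a reshape
y = DataArray(x.values.reshape(2, 3, 5), dims=("a", "b", "c"))
assert y.dims == ("a", "b", "c") and y.shape == (2, 3, 5)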

File tree

pytensor/xtensor/rewriting/shape.py
pytensor/xtensor/shape.py
pytensor/xtensor/type.py
tests/xtensor/test_shape.py

4 files changed: +156 −4 lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 21 additions & 2 deletions

@@ -1,8 +1,8 @@
 from pytensor.graph import node_rewriter
-from pytensor.tensor import broadcast_to, join, moveaxis
+from pytensor.tensor import broadcast_to, join, moveaxis, specify_shape
 from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
 from pytensor.xtensor.rewriting.basic import register_xcanonicalize
-from pytensor.xtensor.shape import Concat, Stack, Transpose
+from pytensor.xtensor.shape import Concat, Stack, Transpose, UnStack


 @register_xcanonicalize
@@ -29,6 +29,25 @@ def lower_stack(fgraph, node):
     return [new_out]


+@register_xcanonicalize
+@node_rewriter(tracks=[UnStack])
+def lower_unstack(fgraph, node):
+    x = node.inputs[0]
+    unstacked_lengths = node.inputs[1:]
+    axis_to_unstack = x.type.dims.index(node.op.old_dim_name)
+
+    x_tensor = tensor_from_xtensor(x)
+    x_tensor_transposed = moveaxis(x_tensor, source=[axis_to_unstack], destination=[-1])
+    final_tensor = x_tensor_transposed.reshape(
+        (*x_tensor_transposed.shape[:-1], *unstacked_lengths)
+    )
+    # Reintroduce any static shape information that was lost during the reshape
+    final_tensor = specify_shape(final_tensor, node.outputs[0].type.shape)
+
+    new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims)
+    return [new_out]
+
+
 @register_xcanonicalize("shape_unsafe")
 @node_rewriter(tracks=[Concat])
 def lower_concat(fgraph, node):
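The lowering strategy deserves a note: the stacked dimension is moved to the last axis, a reshape splits that axis into the requested lengths, and specify_shape reintroduces any static shape information the reshape discarded. A NumPy sketch of the same transformation, with illustrative sizes matching the tests below:

import numpy as np

# Input with dims ("a", "bc", "d"); unstack "bc" into b=3, c=5
x = np.arange(2 * 15 * 7).reshape(2, 15, 7)
axis_to_unstack = 1          # position of "bc"
unstacked_lengths = (3, 5)   # lengths of the new dims b and c

# Step 1: move the stacked axis to the end
x_t = np.moveaxis(x, source=axis_to_unstack, destination=-1)
# Step 2: split the last axis into the requested lengths
out = x_t.reshape((*x_t.shape[:-1], *unstacked_lengths))
assert out.shape == (2, 7, 3, 5)  # dims ("a", "d", "b", "c")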

pytensor/xtensor/shape.py

Lines changed: 86 additions & 1 deletion

@@ -4,7 +4,9 @@

 from pytensor import Variable
 from pytensor.graph import Apply
-from pytensor.scalar import upcast
+from pytensor.scalar import discrete_dtypes, upcast
+from pytensor.tensor import as_tensor, get_scalar_constant_value
+from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.xtensor.basic import XOp
 from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor

@@ -75,6 +77,89 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
     return y


+class UnStack(XOp):
+    __props__ = ("old_dim_name", "unstacked_dims")
+
+    def __init__(
+        self,
+        old_dim_name: str,
+        unstacked_dims: tuple[str, ...],
+    ):
+        super().__init__()
+        if old_dim_name in unstacked_dims:
+            raise ValueError(
+                f"Dim to be unstacked {old_dim_name} can't be in {unstacked_dims}"
+            )
+        if not unstacked_dims:
+            raise ValueError("Dims to unstack into can't be empty.")
+        if len(unstacked_dims) == 1:
+            raise ValueError("Only one dimension to unstack into, use rename instead")
+        self.old_dim_name = old_dim_name
+        self.unstacked_dims = unstacked_dims
+
+    def make_node(self, x, *unstacked_length):
+        x = as_xtensor(x)
+        if self.old_dim_name not in x.type.dims:
+            raise ValueError(
+                f"Dim to unstack {self.old_dim_name} must be in {x.type.dims}"
+            )
+        if not set(self.unstacked_dims).isdisjoint(x.type.dims):
+            raise ValueError(
+                f"Dims to unstack into {self.unstacked_dims} must not be in {x.type.dims}"
+            )
+
+        if len(unstacked_length) != len(self.unstacked_dims):
+            raise ValueError(
+                f"Number of unstacked lengths {len(unstacked_length)} must match number of unstacked dims {len(self.unstacked_dims)}"
+            )
+        unstacked_lengths = [as_tensor(length, ndim=0) for length in unstacked_length]
+        if not all(length.dtype in discrete_dtypes for length in unstacked_lengths):
+            raise TypeError("Unstacked lengths must be discrete dtypes.")
+
+        if x.type.ndim == 1:
+            batch_dims, batch_shape = (), ()
+        else:
+            batch_dims, batch_shape = zip(
+                *(
+                    (dim, shape)
+                    for dim, shape in zip(x.type.dims, x.type.shape)
+                    if dim != self.old_dim_name
+                )
+            )
+
+        static_unstacked_lengths = [None] * len(unstacked_lengths)
+        for i, length in enumerate(unstacked_lengths):
+            try:
+                static_length = get_scalar_constant_value(length)
+            except NotScalarConstantError:
+                pass
+            else:
+                static_unstacked_lengths[i] = int(static_length)
+
+        output = xtensor(
+            dtype=x.type.dtype,
+            shape=(*batch_shape, *static_unstacked_lengths),
+            dims=(*batch_dims, *self.unstacked_dims),
+        )
+        return Apply(self, [x, *unstacked_lengths], [output])
+
+
+def unstack(x, dim: dict[str, dict[str, int]] | None = None, **dims: dict[str, int]):
+    if dim is not None:
+        if dims:
+            raise ValueError(
+                "Cannot use both positional dim and keyword dims in unstack"
+            )
+        dims = dim
+
+    y = x
+    for old_dim_name, unstacked_dict in dims.items():
+        y = UnStack(old_dim_name, tuple(unstacked_dict.keys()))(
+            y, *tuple(unstacked_dict.values())
+        )
+    return y
+
+
 class Transpose(XOp):
     __props__ = ("dims",)
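At the user level, unstack maps the stacked dimension to the new dimensions and their lengths, and make_node folds compile-time-constant lengths into the static output shape. A short usage sketch based on the signature above (names and sizes are illustrative):

from pytensor.xtensor.shape import unstack
from pytensor.xtensor.type import xtensor

x = xtensor("x", dims=("a", "bc"), shape=(2, 15))

# Keyword and positional-dict spellings build the same UnStack node
y = unstack(x, bc={"b": 3, "c": 5})
y = unstack(x, {"bc": {"b": 3, "c": 5}})

# Constant lengths propagate into the static output type
assert y.type.dims == ("a", "b", "c")
assert y.type.shape == (2, 3, 5)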

pytensor/xtensor/type.py

Lines changed: 3 additions & 0 deletions

@@ -309,6 +309,9 @@ def rename(self, new_name_or_name_dict=None, **names):
     def stack(self, dim, **dims):
         return px.shape.stack(self, dim, **dims)

+    def unstack(self, dim, **dims):
+        return px.shape.unstack(self, dim, **dims)
+
     # def swap_dims(self, *args, **kwargs):
     #     ...
     #
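With the method in place, unstack chains directly off an XTensorVariable, mirroring xarray's DataArray API. A sketch assuming x has dims ("a", "b", "c", "d") and shape (2, 3, 5, 7), as in test_stack_unstack below:

# Round trip: collapse b and d into "bd", then split it back
y = x.stack(bd=("b", "d"))
z = y.unstack(bd={"b": 3, "d": 7})
# The unstacked dims are appended after the remaining batch dims
assert z.type.dims == ("a", "c", "b", "d")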

tests/xtensor/test_shape.py

Lines changed: 46 additions & 1 deletion

@@ -9,9 +9,10 @@
 from itertools import chain, combinations

 import numpy as np
+from xarray import DataArray
 from xarray import concat as xr_concat

-from pytensor.xtensor.shape import concat, stack, transpose
+from pytensor.xtensor.shape import concat, stack, transpose, unstack
 from pytensor.xtensor.type import xtensor
 from tests.xtensor.util import (
     xr_arange_like,

@@ -153,6 +154,50 @@ def test_multiple_stacks():
     xr_assert_allclose(res[0], expected_res)


+def test_unstack_constant_size():
+    x = xtensor("x", dims=("a", "bc", "d"), shape=(2, 3 * 5, 7))
+    y = unstack(x, bc=dict(b=3, c=5))
+    assert y.type.dims == ("a", "d", "b", "c")
+    assert y.type.shape == (2, 7, 3, 5)
+
+    fn = xr_function([x], y)
+
+    x_test = xr_arange_like(x)
+    x_np = x_test.values
+    res = fn(x_test)
+    expected = (
+        DataArray(x_np.reshape(2, 3, 5, 7), dims=("a", "b", "c", "d"))
+        .stack(bc=("b", "c"))
+        .unstack("bc")
+    )
+    xr_assert_allclose(res, expected)
+
+
+def test_unstack_symbolic_size():
+    x = xtensor(dims=("a", "b", "c"))
+    y = stack(x, bc=("b", "c"))
+    y = y / y.sum("bc")
+    z = unstack(y, bc={"b": x.sizes["b"], "c": x.sizes["c"]})
+    x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 3, 5)))
+    fn = xr_function([x], z)
+    res = fn(x_test)
+    b_idx, c_idx = np.unravel_index(np.arange(15)[::-1].reshape((3, 5)), (3, 5))
+    expected_res = x_test / x_test.sum(["b", "c"])
+    xr_assert_allclose(res, expected_res)
+
+
+def test_stack_unstack():
+    x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 5, 7))
+    stack_x = stack(x, bd=("b", "d"))
+    unstack_x = unstack(stack_x, bd=dict(b=3, d=7))
+
+    x_test = xr_arange_like(x)
+    fn = xr_function([x], unstack_x)
+    res = fn(x_test)
+    expected_res = x_test.transpose("a", "c", "b", "d")
+    xr_assert_allclose(res, expected_res)
+
+
 @pytest.mark.parametrize("dim", ("a", "b", "new"))
 def test_concat(dim):
     rng = np.random.default_rng(sum(map(ord, dim)))
