Add transpose() for labeled tensors #1427


Closed
20 commits
25397c8
WIP Implement index operations for XTensorVariables
ricardoV94 May 21, 2025
e32d865
Add diff method to XTensorVariable
ricardoV94 May 26, 2025
5988cec
Add transpose operation for labeled tensors with ellipsis support
AllenDowney May 27, 2025
5936ab2
Refactor: Extract ellipsis expansion logic into helper function
AllenDowney May 27, 2025
6fc7b89
Fix lint errors: remove trailing whitespace from docstrings
AllenDowney May 27, 2025
0778cf7
Format files with ruff
AllenDowney May 27, 2025
c7ce0c9
Remove commented out line
AllenDowney May 27, 2025
bc2cbc0
Add missing_dims parameter to transpose for XTensorVariable and core,…
AllenDowney May 28, 2025
7bfa2b2
Add missing_dims parameter to transpose for XTensorVariable and core,…
AllenDowney May 28, 2025
d4f5512
Fix linting issues: remove unused Union import and use dict.fromkeys()
AllenDowney May 28, 2025
1ed01c4
Improve expand_ellipsis with validate parameter and update tests
AllenDowney May 28, 2025
4f010e0
Apply ruff-format to shape.py, type.py, and test_shape.py for consist…
AllenDowney May 28, 2025
f0ea583
Simplify make_node in Transpose class by combining ignore/warn cases
AllenDowney May 28, 2025
0125bd2
Format expand_ellipsis call for better readability
AllenDowney May 28, 2025
30e1a42
WIP Implement index operations for XTensorVariables
ricardoV94 May 21, 2025
29b954a
Add diff method to XTensorVariable
ricardoV94 May 26, 2025
a76b15e
Format and simplify expand_ellipsis; auto-fix with pre-commit; update…
AllenDowney May 28, 2025
af14c90
Improve expand_dims: add tests, fix reshape usage, and ensure code st…
AllenDowney May 28, 2025
6208092
Merge WIP changes from origin/labeled_tensors
AllenDowney May 28, 2025
15f4c48
Implement squeeze
AllenDowney May 28, 2025
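The transpose semantics targeted by these commits follow xarray's, where an ellipsis stands in for all remaining dimensions in their current order and missing_dims controls how unknown names are handled. A quick xarray reference example (xarray itself, not the new PyTensor API, which appears in the diffs below):

import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros((2, 3, 4)), dims=("a", "b", "c"))
assert da.transpose().dims == ("c", "b", "a")          # no args: reverse all dims
assert da.transpose("c", ...).dims == ("c", "a", "b")  # ellipsis: remaining dims keep their order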
1 change: 0 additions & 1 deletion pytensor/xtensor/__init__.py
@@ -7,7 +7,6 @@
)
from pytensor.xtensor.shape import concat
from pytensor.xtensor.type import (
XTensorType,
as_xtensor,
xtensor,
xtensor_constant,
142 changes: 142 additions & 0 deletions pytensor/xtensor/indexing.py
@@ -0,0 +1,142 @@
# HERE LIE DRAGONS
# Useful links to make sense of all the numpy/xarray complexity
# https://numpy.org/devdocs//user/basics.indexing.html
# https://numpy.org/neps/nep-0021-advanced-indexing.html
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html

from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.scalar.basic import discrete_dtypes
from pytensor.tensor import TensorType
from pytensor.tensor.basic import as_tensor
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
from pytensor.xtensor.basic import XOp, xtensor_from_tensor
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


def as_idx_variable(idx):
if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)):
raise TypeError(
"XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
)
if isinstance(idx, slice):
idx = make_slice(idx)
elif isinstance(idx, Variable) and isinstance(idx.type, SliceType):
pass
elif isinstance(idx, tuple) and len(idx) == 2 and isinstance(idx[0], str):
# Special case for ("x", array) that xarray supports
# TODO: Check if this can be used to rename existing xarray dimensions or only for numpy
dim, idx = idx
idx = xtensor_from_tensor(as_tensor(idx), dims=(dim,))
else:
# Must be integer indices, we already accounted for None and slices
try:
idx = as_xtensor(idx)
except TypeError:
idx = as_tensor(idx)
if idx.type.dtype == "bool":
raise NotImplementedError("Boolean indexing not yet supported")
if idx.type.dtype not in discrete_dtypes:
raise TypeError("Numerical indices must be integers or boolean")
if idx.type.dtype == "bool" and idx.type.ndim == 0:
# This can't be triggered right now, but will once we lift the boolean restriction
raise NotImplementedError("Scalar boolean indices not supported")
return idx


def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
if dim_length is None:
return None
if isinstance(slc, Constant):
d = slc.data
start, stop, step = d.start, d.stop, d.step
elif slc.owner is None:
# It's a root variable; no way of knowing what we're getting
return None
else:
# It's a MakeSliceOp
start, stop, step = slc.owner.inputs
if isinstance(start, Constant):
start = start.data
else:
return None
if isinstance(stop, Constant):
stop = stop.data
else:
return None
if isinstance(step, Constant):
step = step.data
else:
return None
return len(range(*slice(start, stop, step).indices(dim_length)))


class Index(XOp):
__props__ = ()

def make_node(self, x, *idxs):
x = as_xtensor(x)
idxs = [as_idx_variable(idx) for idx in idxs]

x_ndim = x.type.ndim
x_dims = x.type.dims
x_shape = x.type.shape
out_dims = []
out_shape = []
for i, idx in enumerate(idxs):
if i == x_ndim:
raise IndexError("Too many indices")
if isinstance(idx.type, SliceType):
out_dims.append(x_dims[i])
out_shape.append(get_static_slice_length(idx, x_shape[i]))
else:
if idx.type.ndim == 0:
# Scalar index, dimension is dropped
continue

if isinstance(idx.type, TensorType):
if idx.type.ndim > 1:
# Same error that xarray raises
raise IndexError(
"Unlabeled multi-dimensional array cannot be used for indexing"
)

# This is implicitly an XTensorVariable with dim matching the indexed one
idx = idxs[i] = xtensor_from_tensor(idx, dims=(x_dims[i],))

assert isinstance(idx.type, XTensorType)

idx_dims = idx.type.dims
for dim in idx_dims:
idx_dim_shape = idx.type.shape[idx_dims.index(dim)]
if dim in out_dims:
# Dim already introduced in output by a previous index
# Update static shape or raise if incompatible
out_dim_pos = out_dims.index(dim)
out_dim_shape = out_shape[out_dim_pos]
if out_dim_shape is None:
# We don't know the size of the dimension yet
out_shape[out_dim_pos] = idx_dim_shape
elif (
idx_dim_shape is not None and idx_dim_shape != out_dim_shape
):
raise IndexError(
f"Dimension of indexers mismatch for dim {dim}"
)
else:
# New dimension
out_dims.append(dim)
out_shape.append(idx_dim_shape)

for dim_i, shape_i in zip(x_dims[i + 1 :], x_shape[i + 1 :]):
# Add back any unindexed dimensions
if dim_i not in out_dims:
# If the dimension was not indexed, we keep it as is
out_dims.append(dim_i)
out_shape.append(shape_i)

output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x, *idxs], [output])


index = Index()
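A minimal usage sketch of the Index op defined above (assuming xtensor() accepts a name plus dims/shape keywords, as elsewhere in pytensor.xtensor): a scalar index drops its dimension, while a slice keeps the dimension and, when its bounds are constant, its static length.

from pytensor.xtensor.indexing import index
from pytensor.xtensor.type import xtensor

x = xtensor("x", dims=("a", "b"), shape=(3, 4))
out = index(x, 0, slice(1, 3))   # scalar drops "a"; slice keeps "b"
assert out.type.dims == ("b",)
assert out.type.shape == (2,)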
1 change: 1 addition & 0 deletions pytensor/xtensor/rewriting/__init__.py
@@ -1,4 +1,5 @@
import pytensor.xtensor.rewriting.basic
import pytensor.xtensor.rewriting.indexing
import pytensor.xtensor.rewriting.reduction
import pytensor.xtensor.rewriting.shape
import pytensor.xtensor.rewriting.vectorization
102 changes: 102 additions & 0 deletions pytensor/xtensor/rewriting/indexing.py
@@ -0,0 +1,102 @@
from itertools import zip_longest

from pytensor import as_symbolic
from pytensor.graph import Constant, node_rewriter
from pytensor.tensor import arange, specify_shape
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.indexing import Index
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
from pytensor.xtensor.type import XTensorType


def to_basic_idx(idx):
if isinstance(idx.type, SliceType):
if isinstance(idx, Constant):
return idx.data
elif idx.owner:
# MakeSlice Op
# We transform NoneConsts to regular None so that basic Subtensor can be used if possible
return slice(
*[
None if isinstance(i.type, NoneTypeT) else i
for i in idx.owner.inputs
]
)
else:
return idx
if (
isinstance(idx.type, XTensorType)
and idx.type.ndim == 0
and idx.type.dtype != "bool"
):
return idx.values
raise TypeError("Cannot convert idx to basic idx")


@register_xcanonicalize
@node_rewriter(tracks=[Index])
def lower_index(fgraph, node):
x, *idxs = node.inputs
[out] = node.outputs
x_tensor = tensor_from_xtensor(x)

if all(
(
isinstance(idx.type, SliceType)
or (isinstance(idx.type, XTensorType) and idx.type.ndim == 0)
)
for idx in idxs
):
# Special case just basic indexing
x_tensor_indexed = x_tensor[tuple(to_basic_idx(idx) for idx in idxs)]

else:
# General case, we have to align the indices positionally to achieve vectorized or orthogonal indexing
# May need to convert basic indexing to advanced indexing if it acts on a dimension
# that is also indexed by an advanced index
x_dims = x.type.dims
x_shape = tuple(x.shape)
out_ndim = out.type.ndim
out_xdims = out.type.dims
aligned_idxs = []
# zip_longest adds the implicit slice(None)
for i, (idx, x_dim) in enumerate(
zip_longest(idxs, x_dims, fillvalue=as_symbolic(slice(None)))
):
if isinstance(idx.type, SliceType):
if not any(
(
isinstance(other_idx.type, XTensorType)
and x_dim in other_idx.dims
)
for j, other_idx in enumerate(idxs)
if j != i
):
# We can use basic indexing directly if no other index acts on this dimension
aligned_idxs.append(idx)
else:
# Otherwise we need to convert the basic index into an equivalent advanced indexing
# And align it so it interacts correctly with the other advanced indices
adv_idx_equivalent = arange(x_shape[i])[idx]
ds_order = ["x"] * out_ndim
ds_order[out_xdims.index(x_dim)] = 0
aligned_idxs.append(adv_idx_equivalent.dimshuffle(ds_order))
else:
assert isinstance(idx.type, XTensorType)
if idx.type.ndim == 0:
# Scalar index, we can use it directly
aligned_idxs.append(idx.values)
else:
# Vector index, we need to align the indexing dimensions with the base_dims
ds_order = ["x"] * out_ndim
for j, idx_dim in enumerate(idx.dims):
ds_order[out_xdims.index(idx_dim)] = j
aligned_idxs.append(idx.values.dimshuffle(ds_order))
x_tensor_indexed = x_tensor[tuple(aligned_idxs)]
# TODO: Align output dimensions if necessary

# Add lost shape if any
x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape)
new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims)
return [new_out]
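The slice-to-advanced-index alignment in lower_index mirrors how orthogonal (xarray-style) indexing can be emulated in plain NumPy: each labeled index is broadcast on its own output axis, and a basic slice that shares a dimension with an advanced index is replaced by arange(dim_length)[slice] aligned the same way. A NumPy-only sketch of the idea, with hypothetical dims "a"/"b":

import numpy as np

x = np.arange(12).reshape(3, 4)        # dims ("a", "b")
row_idx = np.array([0, 2])             # labeled index introducing output dim "i"
col_idx = np.arange(x.shape[1])[1:3]   # slice 1:3 on "b", as an advanced index -> [1, 2]

# Broadcast row_idx along axis 0 and col_idx along axis 1: orthogonal indexing
out = x[row_idx[:, None], col_idx[None, :]]
assert out.shape == (2, 2)
assert (out == x[[0, 2]][:, 1:3]).all()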
39 changes: 38 additions & 1 deletion pytensor/xtensor/rewriting/shape.py
@@ -2,7 +2,7 @@
from pytensor.tensor import broadcast_to, join, moveaxis
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_xcanonicalize
from pytensor.xtensor.shape import Concat, Stack
from pytensor.xtensor.shape import Concat, ExpandDims, Squeeze, Stack, Transpose


@register_xcanonicalize
@@ -70,3 +70,40 @@ def lower_concat(fgraph, node):
joined_tensor = join(concat_axis, *bcast_tensor_inputs)
new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
return [new_out]


@register_xcanonicalize
@node_rewriter(tracks=[Transpose])
def lower_transpose(fgraph, node):
[x] = node.inputs
# Use the final dimensions that were already computed in make_node
out_dims = node.outputs[0].type.dims
in_dims = x.type.dims

# Compute the permutation based on the final dimensions
perm = tuple(in_dims.index(d) for d in out_dims)
x_tensor = tensor_from_xtensor(x)
x_tensor_transposed = x_tensor.transpose(perm)
new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims)
return [new_out]


@register_xcanonicalize
@node_rewriter(tracks=[ExpandDims])
def lower_expand_dims(fgraph, node):
[x] = node.inputs
x_tensor = tensor_from_xtensor(x)
# Append a trailing length-1 axis for the newly added dimension
x_tensor_expanded = x_tensor.reshape((*x_tensor.shape, 1))
new_out = xtensor_from_tensor(x_tensor_expanded, dims=node.outputs[0].type.dims)
return [new_out]


@register_xcanonicalize
@node_rewriter(tracks=[Squeeze])
def lower_squeeze(fgraph, node):
[x] = node.inputs
x_tensor = tensor_from_xtensor(x)
expected_shape = tuple(node.outputs[0].type.shape)
# Reshape to the output's static shape (assumes the retained dims have known static sizes)
x_tensor_squeezed = x_tensor.reshape(expected_shape)
new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
return [new_out]
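What lower_transpose computes, reduced to plain Python over dimension names (a sketch of the permutation logic, not the rewrite itself):

in_dims = ("time", "lat", "lon")
out_dims = ("lon", "time", "lat")                 # order requested by transpose()
perm = tuple(in_dims.index(d) for d in out_dims)
assert perm == (2, 0, 1)   # axes handed to the backing TensorVariable's .transpose()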