diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py
index a72bf66c79..06265e40de 100644
--- a/pytensor/xtensor/__init__.py
+++ b/pytensor/xtensor/__init__.py
@@ -7,7 +7,6 @@
 )
 from pytensor.xtensor.shape import concat
 from pytensor.xtensor.type import (
-    XTensorType,
     as_xtensor,
     xtensor,
     xtensor_constant,
diff --git a/pytensor/xtensor/indexing.py b/pytensor/xtensor/indexing.py
new file mode 100644
index 0000000000..33eb29f051
--- /dev/null
+++ b/pytensor/xtensor/indexing.py
@@ -0,0 +1,142 @@
+# HERE LIE DRAGONS
+# Useful links to make sense of all the numpy/xarray complexity
+# https://numpy.org/devdocs//user/basics.indexing.html
+# https://numpy.org/neps/nep-0021-advanced-indexing.html
+# https://docs.xarray.dev/en/latest/user-guide/indexing.html
+# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
+
+from pytensor.graph.basic import Apply, Constant, Variable
+from pytensor.scalar.basic import discrete_dtypes
+from pytensor.tensor import TensorType
+from pytensor.tensor.basic import as_tensor
+from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
+from pytensor.xtensor.basic import XOp, xtensor_from_tensor
+from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
+
+
+def as_idx_variable(idx):
+    if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)):
+        raise TypeError(
+            "XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
+        )
+    if isinstance(idx, slice):
+        idx = make_slice(idx)
+    elif isinstance(idx, Variable) and isinstance(idx.type, SliceType):
+        pass
+    elif isinstance(idx, tuple) and len(idx) == 2 and isinstance(idx[0], str):
+        # Special case for ("x", array) that xarray supports
+        # TODO: Check if this can be used to rename existing xarray dimensions or only for numpy
+        dim, idx = idx
+        idx = xtensor_from_tensor(as_tensor(idx), dims=(dim,))
+    else:
+        # Must be integer indices, we already accounted for None and slices
+        try:
+            idx = as_xtensor(idx)
+        except TypeError:
+            idx = as_tensor(idx)
+        if idx.type.dtype == "bool":
+            raise NotImplementedError("Boolean indexing not yet supported")
+        if idx.type.dtype not in discrete_dtypes:
+            raise TypeError("Numerical indices must be integers or boolean")
+        if idx.type.dtype == "bool" and idx.type.ndim == 0:
+            # This can't be triggered right now, but will once we lift the boolean restriction
+            raise NotImplementedError("Scalar boolean indices not supported")
+    return idx
+
+
+def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
+    if dim_length is None:
+        return None
+    if isinstance(slc, Constant):
+        d = slc.data
+        start, stop, step = d.start, d.stop, d.step
+    elif slc.owner is None:
+        # It's a root variable; there is no way of knowing what we're getting
+        return None
+    else:
+        # It's a MakeSliceOp
+        start, stop, step = slc.owner.inputs
+        if isinstance(start, Constant):
+            start = start.data
+        else:
+            return None
+        if isinstance(stop, Constant):
+            stop = stop.data
+        else:
+            return None
+        if isinstance(step, Constant):
+            step = step.data
+        else:
+            return None
+    return len(range(*slice(start, stop, step).indices(dim_length)))
+
+
+class Index(XOp):
+    __props__ = ()
+
+    def make_node(self, x, *idxs):
+        x = as_xtensor(x)
+        idxs = [as_idx_variable(idx) for idx in idxs]
+
+        x_ndim = x.type.ndim
+        x_dims = x.type.dims
+        x_shape = x.type.shape
+        out_dims = []
+        out_shape = []
+        for i, idx in enumerate(idxs):
+            if i == x_ndim:
+                raise IndexError("Too many indices")
+            if isinstance(idx.type, SliceType):
+                out_dims.append(x_dims[i])
+                out_shape.append(get_static_slice_length(idx, x_shape[i]))
+            else:
+                if idx.type.ndim == 0:
+                    # Scalar index, dimension is dropped
+                    continue
+
+                if isinstance(idx.type, TensorType):
+                    if idx.type.ndim > 1:
+                        # Same error that xarray raises
+                        raise IndexError(
+                            "Unlabeled multi-dimensional array cannot be used for indexing"
+                        )
+
+                    # This is implicitly an XTensorVariable with dim matching the indexed one
+                    idx = idxs[i] = xtensor_from_tensor(idx, dims=(x_dims[i],))
+
+                assert isinstance(idx.type, XTensorType)
+
+                idx_dims = idx.type.dims
+                for dim in idx_dims:
+                    idx_dim_shape = idx.type.shape[idx_dims.index(dim)]
+                    if dim in out_dims:
+                        # Dim already introduced in output by a previous index
+                        # Update static shape or raise if incompatible
+                        out_dim_pos = out_dims.index(dim)
+                        out_dim_shape = out_shape[out_dim_pos]
+                        if out_dim_shape is None:
+                            # We don't know the size of the dimension yet
+                            out_shape[out_dim_pos] = idx_dim_shape
+                        elif (
+                            idx_dim_shape is not None and idx_dim_shape != out_dim_shape
+                        ):
+                            raise IndexError(
+                                f"Dimension of indexers mismatch for dim {dim}"
+                            )
+                    else:
+                        # New dimension
+                        out_dims.append(dim)
+                        out_shape.append(idx_dim_shape)
+
+        for dim_i, shape_i in zip(x_dims[i + 1 :], x_shape[i + 1 :]):
+            # Add back any unindexed dimensions
+            if dim_i not in out_dims:
+                # If the dimension was not indexed, we keep it as is
+                out_dims.append(dim_i)
+                out_shape.append(shape_i)
+
+        output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
+        return Apply(self, [x, *idxs], [output])
+
+
+index = Index()
diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py
index 7ce55b9256..a65ad0db85 100644
--- a/pytensor/xtensor/rewriting/__init__.py
+++ b/pytensor/xtensor/rewriting/__init__.py
@@ -1,4 +1,5 @@
 import pytensor.xtensor.rewriting.basic
+import pytensor.xtensor.rewriting.indexing
 import pytensor.xtensor.rewriting.reduction
 import pytensor.xtensor.rewriting.shape
 import pytensor.xtensor.rewriting.vectorization
diff --git a/pytensor/xtensor/rewriting/indexing.py b/pytensor/xtensor/rewriting/indexing.py
new file mode 100644
index 0000000000..3d9ac3d99b
--- /dev/null
+++ b/pytensor/xtensor/rewriting/indexing.py
@@ -0,0 +1,102 @@
+from itertools import zip_longest
+
+from pytensor import as_symbolic
+from pytensor.graph import Constant, node_rewriter
+from pytensor.tensor import arange, specify_shape
+from pytensor.tensor.type_other import NoneTypeT, SliceType
+from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
+from pytensor.xtensor.indexing import Index
+from pytensor.xtensor.rewriting.utils import register_xcanonicalize
+from pytensor.xtensor.type import XTensorType
+
+
+def to_basic_idx(idx):
+    if isinstance(idx.type, SliceType):
+        if isinstance(idx, Constant):
+            return idx.data
+        elif idx.owner:
+            # MakeSlice Op
+            # We transform NoneConsts to regular None so that basic Subtensor can be used if possible
+            return slice(
+                *[
+                    None if isinstance(i.type, NoneTypeT) else i
+                    for i in idx.owner.inputs
+                ]
+            )
+        else:
+            return idx
+    if (
+        isinstance(idx.type, XTensorType)
+        and idx.type.ndim == 0
+        and idx.type.dtype != "bool"
+    ):
+        return idx.values
+    raise TypeError("Cannot convert idx to basic idx")
+
+
+@register_xcanonicalize
+@node_rewriter(tracks=[Index])
+def lower_index(fgraph, node):
+    x, *idxs = node.inputs
+    [out] = node.outputs
+    x_tensor = tensor_from_xtensor(x)
+
+    if all(
+        (
+            isinstance(idx.type, SliceType)
+            or (isinstance(idx.type, XTensorType) and idx.type.ndim == 0)
+        )
+        for idx in idxs
+    ):
+        # Special case: just basic indexing
+        x_tensor_indexed = x_tensor[tuple(to_basic_idx(idx) for idx in idxs)]
+
+    else:
+        # General case, we have to align the indices positionally to achieve vectorized or orthogonal indexing
+        # May need to convert basic indexing to advanced indexing if it acts on a dimension
+        # that is also indexed by an advanced index
+        x_dims = x.type.dims
+        x_shape = tuple(x.shape)
+        out_ndim = out.type.ndim
+        out_xdims = out.type.dims
+        aligned_idxs = []
+        # zip_longest adds the implicit slice(None)
+        for i, (idx, x_dim) in enumerate(
+            zip_longest(idxs, x_dims, fillvalue=as_symbolic(slice(None)))
+        ):
+            if isinstance(idx.type, SliceType):
+                if not any(
+                    (
+                        isinstance(other_idx.type, XTensorType)
+                        and x_dim in other_idx.dims
+                    )
+                    for j, other_idx in enumerate(idxs)
+                    if j != i
+                ):
+                    # We can use basic indexing directly if no other index acts on this dimension
+                    aligned_idxs.append(idx)
+                else:
+                    # Otherwise we need to convert the basic index into an equivalent advanced indexing
+                    # And align it so it interacts correctly with the other advanced indices
+                    adv_idx_equivalent = arange(x_shape[i])[idx]
+                    ds_order = ["x"] * out_ndim
+                    ds_order[out_xdims.index(x_dim)] = 0
+                    aligned_idxs.append(adv_idx_equivalent.dimshuffle(ds_order))
+            else:
+                assert isinstance(idx.type, XTensorType)
+                if idx.type.ndim == 0:
+                    # Scalar index, we can use it directly
+                    aligned_idxs.append(idx.values)
+                else:
+                    # Vector index, we need to align the indexing dimensions with the base_dims
+                    ds_order = ["x"] * out_ndim
+                    for j, idx_dim in enumerate(idx.dims):
+                        ds_order[out_xdims.index(idx_dim)] = j
+                    aligned_idxs.append(idx.values.dimshuffle(ds_order))
+        x_tensor_indexed = x_tensor[tuple(aligned_idxs)]
+        # TODO: Align output dimensions if necessary
+
+    # Add lost shape if any
+    x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape)
+    new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims)
+    return [new_out]
diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py
index 06b8c40a32..2e9c9b6f7f 100644
--- a/pytensor/xtensor/rewriting/shape.py
+++ b/pytensor/xtensor/rewriting/shape.py
@@ -2,7 +2,7 @@
 from pytensor.tensor import broadcast_to, join, moveaxis
 from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
 from pytensor.xtensor.rewriting.basic import register_xcanonicalize
-from pytensor.xtensor.shape import Concat, Stack
+from pytensor.xtensor.shape import Concat, ExpandDims, Squeeze, Stack, Transpose


 @register_xcanonicalize
@@ -70,3 +70,43 @@ def lower_concat(fgraph, node):
     joined_tensor = join(concat_axis, *bcast_tensor_inputs)
     new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
     return [new_out]
+
+
+@register_xcanonicalize
+@node_rewriter(tracks=[Transpose])
+def lower_transpose(fgraph, node):
+    [x] = node.inputs
+    # Use the final dimensions that were already computed in make_node
+    out_dims = node.outputs[0].type.dims
+    in_dims = x.type.dims
+
+    # Compute the permutation based on the final dimensions
+    perm = tuple(in_dims.index(d) for d in out_dims)
+    x_tensor = tensor_from_xtensor(x)
+    x_tensor_transposed = x_tensor.transpose(perm)
+    new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims)
+    return [new_out]
+
+
+@register_xcanonicalize
+@node_rewriter(tracks=[ExpandDims])
+def lower_expand_dims(fgraph, node):
+    [x] = node.inputs
+    x_tensor = tensor_from_xtensor(x)
+    x_tensor_expanded = x_tensor.reshape((*x_tensor.shape, 1))
+    new_out = xtensor_from_tensor(x_tensor_expanded, dims=node.outputs[0].type.dims)
+    return [new_out]
+
+
+@register_xcanonicalize
+@node_rewriter(tracks=[Squeeze])
+def lower_squeeze(fgraph, node):
+    [x] = node.inputs
+    x_tensor = tensor_from_xtensor(x)
+    out_dims = node.outputs[0].type.dims
+    # Drop the squeezed axes via dimshuffle; reshaping to the output's static shape
+    # would fail whenever a remaining dimension has an unknown (None) static length
+    keep_axes = [i for i, d in enumerate(x.type.dims) if d in out_dims]
+    x_tensor_squeezed = x_tensor.dimshuffle(keep_axes)
+    new_out = xtensor_from_tensor(x_tensor_squeezed, dims=out_dims)
+    return [new_out]
diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py
index f39d495285..4ac371a6ca 100644
--- a/pytensor/xtensor/shape.py
+++ b/pytensor/xtensor/shape.py
@@ -1,4 +1,6 @@
+import warnings
 from collections.abc import Sequence
+from typing import Literal

 from pytensor import Variable
 from pytensor.graph import Apply
@@ -73,6 +75,130 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]):
     return y
+
+
+def expand_ellipsis(
+    dims: tuple[str, ...],
+    all_dims: tuple[str, ...],
+    validate: bool = True,
+    missing_dims: Literal["raise", "warn", "ignore"] = "raise",
+) -> tuple[str, ...]:
+    """Expand ellipsis in dimension permutation.
+
+    Parameters
+    ----------
+    dims : tuple[str, ...]
+        The dimension permutation, which may contain ellipsis
+    all_dims : tuple[str, ...]
+        All available dimensions
+    validate : bool, default True
+        Whether to check that all non-ellipsis elements in dims are valid dimension names.
+    missing_dims : {"raise", "warn", "ignore"}, optional
+        How to handle dimensions that don't exist in all_dims:
+        - "raise": Raise an error if any dimensions don't exist (default)
+        - "warn": Warn if any dimensions don't exist
+        - "ignore": Silently ignore any dimensions that don't exist
+
+    Returns
+    -------
+    tuple[str, ...]
+        The expanded dimension permutation
+
+    Raises
+    ------
+    ValueError
+        If more than one ellipsis is present in dims.
+        If any non-ellipsis element in dims is not a valid dimension name and validate is True.
+        If missing_dims is "raise" and any dimension in dims doesn't exist in all_dims.
+    """
+    # Handle empty or full ellipsis case
+    if dims == () or dims == (...,):
+        return tuple(reversed(all_dims))
+
+    # Check for multiple ellipses
+    if dims.count(...) > 1:
+        raise ValueError("an index can only have a single ellipsis ('...')")
+
+    # Validate dimensions if requested
+    if validate:
+        invalid_dims = set(dims) - {..., *all_dims}
+        if invalid_dims:
+            if missing_dims == "raise":
+                raise ValueError(
+                    f"Invalid dimensions: {invalid_dims}. Available dimensions: {all_dims}"
+                )
+            elif missing_dims == "warn":
+                warnings.warn(f"Dimensions {invalid_dims} do not exist in {all_dims}")
+
+    # Handle missing dimensions if not raising
+    if missing_dims in ("ignore", "warn"):
+        dims = tuple(d for d in dims if d in all_dims or d is ...)
+
+    # If no ellipsis, just return the dimensions
+    if ... not in dims:
+        return dims
+
+    # Handle ellipsis expansion
+    ellipsis_idx = dims.index(...)
+    pre = list(dims[:ellipsis_idx])
+    post = list(dims[ellipsis_idx + 1 :])
+    middle = [d for d in all_dims if d not in pre + post]
+    return tuple(pre + middle + post)
+
+
+class Transpose(XOp):
+    __props__ = ("dims", "missing_dims")
+
+    def __init__(
+        self,
+        dims: tuple[str | Literal[...], ...],
+        missing_dims: Literal["raise", "warn", "ignore"] = "raise",
+    ):
+        super().__init__()
+        self.dims = dims
+        self.missing_dims = missing_dims
+
+    def make_node(self, x):
+        x = as_xtensor(x)
+        dims = expand_ellipsis(
+            self.dims, x.type.dims, validate=True, missing_dims=self.missing_dims
+        )
+
+        output = xtensor(
+            dtype=x.type.dtype,
+            shape=tuple(x.type.shape[x.type.dims.index(d)] for d in dims),
+            dims=dims,
+        )
+        return Apply(self, [x], [output])
+
+
+def transpose(x, *dims, missing_dims: Literal["raise", "warn", "ignore"] = "raise"):
+    """Transpose dimensions of the tensor.
+
+    Parameters
+    ----------
+    x : XTensorVariable
+        Input tensor to transpose.
+    *dims : str
+        Dimensions to transpose to. Can include ellipsis (...) to represent
+        remaining dimensions in their original order.
+    missing_dims : {"raise", "warn", "ignore"}, optional
+        How to handle dimensions that don't exist in the input tensor:
+        - "raise": Raise an error if any dimensions don't exist (default)
+        - "warn": Warn if any dimensions don't exist
+        - "ignore": Silently ignore any dimensions that don't exist
+
+    Returns
+    -------
+    XTensorVariable
+        Transposed tensor with reordered dimensions.
+
+    Raises
+    ------
+    ValueError
+        If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise".
+    """
+    return Transpose(dims, missing_dims=missing_dims)(x)
+
+
 class Concat(XOp):
     __props__ = ("dim",)

@@ -123,3 +249,123 @@ def make_node(self, *inputs: Variable) -> Apply:

 def concat(xtensors, dim: str):
     return Concat(dim=dim)(*xtensors)
+
+
+class ExpandDims(XOp):
+    """Add a new dimension to an XTensorVariable.
+
+    Parameters
+    ----------
+    dim : str or None
+        The name of the new dimension. If None, the dimension will be unnamed.
+    """
+
+    __props__ = ("dim",)
+
+    def __init__(self, dim):
+        self.dim = dim
+
+    def make_node(self, x):
+        x = as_xtensor(x)
+
+        # Check if dimension already exists
+        if self.dim is not None and self.dim in x.type.dims:
+            raise ValueError(f"Dimension {self.dim} already exists")
+
+        # Create new dimensions list with the new dimension
+        new_dims = list(x.type.dims)
+        new_dims.append(self.dim)
+
+        # Create new shape with the new dimension
+        new_shape = list(x.type.shape)
+        new_shape.append(1)
+
+        output = xtensor(
+            dtype=x.type.dtype, shape=tuple(new_shape), dims=tuple(new_dims)
+        )
+        return Apply(self, [x], [output])
+
+
+def expand_dims(x, dim: str | None):
+    """Add a new dimension to an XTensorVariable.
+
+    Parameters
+    ----------
+    x : XTensorVariable
+        The input tensor
+    dim : str or None
+        The name of the new dimension
+
+    Returns
+    -------
+    XTensorVariable
+        A new tensor with the expanded dimension
+    """
+    return ExpandDims(dim=dim)(x)
+
+
+class Squeeze(XOp):
+    """Remove a dimension of size 1 from an XTensorVariable.
+
+    Parameters
+    ----------
+    dim : str or None
+        The name of the dimension to remove. If None, all dimensions of size 1 will be removed.
+ """ + + def __init__(self, dim=None): + self.dim = dim + + def make_node(self, x): + x = as_xtensor(x) + + # Get the index of the dimension to remove + if self.dim is not None: + if self.dim not in x.type.dims: + raise ValueError(f"Dimension {self.dim} not found") + dim_idx = x.type.dims.index(self.dim) + if x.type.shape[dim_idx] != 1: + raise ValueError( + f"Dimension {self.dim} has size {x.type.shape[dim_idx]}, not 1" + ) + else: + # Find all dimensions of size 1 + dim_idx = [i for i, s in enumerate(x.type.shape) if s == 1] + if not dim_idx: + raise ValueError("No dimensions of size 1 to remove") + + # Create new dimensions and shape lists + new_dims = list(x.type.dims) + new_shape = list(x.type.shape) + if self.dim is not None: + new_dims.pop(dim_idx) + new_shape.pop(dim_idx) + else: + # Remove all dimensions of size 1 + new_dims = [d for i, d in enumerate(new_dims) if i not in dim_idx] + new_shape = [s for i, s in enumerate(new_shape) if i not in dim_idx] + + output = xtensor( + dtype=x.type.dtype, shape=tuple(new_shape), dims=tuple(new_dims) + ) + return Apply(self, [x], [output]) + + +def squeeze(x, dim=None): + """Remove a dimension of size 1 from an XTensorVariable. + + Parameters + ---------- + x : XTensorVariable + The input tensor + dim : str or None, optional + The name of the dimension to remove. If None, all dimensions of size 1 will be removed. + + Returns + ------- + XTensorVariable + A new tensor with the specified dimension removed + """ + return Squeeze(dim=dim)(x) diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 5b79e9ae57..647dc110b9 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -1,3 +1,5 @@ +import warnings + from pytensor.tensor import TensorType from pytensor.tensor.math import variadic_mul @@ -10,7 +12,7 @@ XARRAY_AVAILABLE = False from collections.abc import Sequence -from typing import TypeVar +from typing import Any, Literal, TypeVar import numpy as np @@ -339,7 +341,115 @@ def sel(self, *args, **kwargs): raise NotImplementedError("sel not implemented for XTensorVariable") def __getitem__(self, idx): - raise NotImplementedError("Indexing not yet implemnented") + if isinstance(idx, dict): + return self.isel(idx) + + if not isinstance(idx, tuple): + idx = (idx,) + + # Check for ellipsis not in the last position (last one is useless anyway) + if any(idx_item is Ellipsis for idx_item in idx): + if idx.count(Ellipsis) > 1: + raise IndexError("an index can only have a single ellipsis ('...')") + # Convert intermediate Ellipsis to slice(None) + ellipsis_loc = idx.index(Ellipsis) + n_implied_none_slices = self.type.ndim - (len(idx) - 1) + idx = ( + *idx[:ellipsis_loc], + *((slice(None),) * n_implied_none_slices), + *idx[ellipsis_loc + 1 :], + ) + + return px.indexing.index(self, *idx) + + def isel( + self, + indexers: dict[str, Any] | None = None, + drop: bool = False, # Unused by PyTensor + missing_dims: Literal["raise", "warn", "ignore"] = "raise", + **indexers_kwargs, + ): + if indexers_kwargs: + if indexers is not None: + raise ValueError( + "Cannot pass both indexers and indexers_kwargs to isel" + ) + indexers = indexers_kwargs + + if missing_dims not in {"raise", "warn", "ignore"}: + raise ValueError( + f"Unrecognized options {missing_dims} for missing_dims argument" + ) + + # Sort indices and pass them to index + dims = self.type.dims + indices = [slice(None)] * self.type.ndim + for key, idx in indexers.items(): + if idx is Ellipsis: + # Xarray raises a less informative error, suggesting indices must be integer + 
+                # But slices are also fine
+                raise TypeError("Ellipsis (...) is an invalid labeled index")
+            try:
+                indices[dims.index(key)] = idx
+            except ValueError:
+                if missing_dims == "raise":
+                    raise ValueError(
+                        f"Dimension {key} does not exist. Expected one of {dims}"
+                    )
+                elif missing_dims == "warn":
+                    warnings.warn(
+                        f"Dimension {key} does not exist. Expected one of {dims}",
+                        UserWarning,
+                    )
+
+        return px.indexing.index(self, *indices)
+
+    def _head_tail_or_thin(
+        self,
+        indexers: dict[str, Any] | int | None,
+        indexers_kwargs: dict[str, Any],
+        *,
+        kind: Literal["head", "tail", "thin"],
+    ):
+        if indexers_kwargs:
+            if indexers is not None:
+                raise ValueError(
+                    f"Cannot pass both indexers and indexers_kwargs to {kind}"
+                )
+            indexers = indexers_kwargs
+
+        if indexers is None:
+            if kind == "thin":
+                raise TypeError(
+                    "thin() indexers must be either dict-like or a single integer"
+                )
+            else:
+                # Default to 5 for head and tail
+                indexers = {dim: 5 for dim in self.type.dims}
+
+        elif not isinstance(indexers, dict):
+            indexers = {dim: indexers for dim in self.type.dims}
+
+        if kind == "head":
+            indices = {dim: slice(None, value) for dim, value in indexers.items()}
+        elif kind == "tail":
+            sizes = self.sizes
+            # Can't use slice(-value, None), in case value is zero
+            indices = {
+                dim: slice(sizes[dim] - value, None) for dim, value in indexers.items()
+            }
+        elif kind == "thin":
+            indices = {dim: slice(None, None, value) for dim, value in indexers.items()}
+        return self.isel(indices)
+
+    def head(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
+        return self._head_tail_or_thin(indexers, indexers_kwargs, kind="head")
+
+    def tail(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
+        return self._head_tail_or_thin(indexers, indexers_kwargs, kind="tail")
+
+    def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
+        return self._head_tail_or_thin(indexers, indexers_kwargs, kind="thin")

     # ndarray methods
     # https://docs.xarray.dev/en/latest/api.html#id7
@@ -357,6 +467,47 @@ def imag(self):
         return px.math.imag(self)

     @property
     def real(self):
         return px.math.real(self)

+    def transpose(
+        self, *dims, missing_dims: Literal["raise", "warn", "ignore"] = "raise"
+    ):
+        """Transpose dimensions of the tensor.
+
+        Parameters
+        ----------
+        *dims : str
+            Dimensions to transpose to. Can include ellipsis (...) to represent
+            remaining dimensions in their original order.
+        missing_dims : {"raise", "warn", "ignore"}, optional
+            How to handle dimensions that don't exist in the input tensor:
+            - "raise": Raise an error if any dimensions don't exist (default)
+            - "warn": Warn if any dimensions don't exist
+            - "ignore": Silently ignore any dimensions that don't exist
+
+        Returns
+        -------
+        XTensorVariable
+            Transposed tensor with reordered dimensions.
+
+        Raises
+        ------
+        ValueError
+            If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise".
+        """
+        from pytensor.xtensor.shape import transpose
+
+        return transpose(self, *dims, missing_dims=missing_dims)
+
+    @property
+    def T(self):
+        """Transpose all dimensions of the tensor, reversing their order.
+
+        Returns
+        -------
+        XTensorVariable
+            Transposed tensor with reversed dimensions.
+ """ + return self.transpose() + # Aggregation # https://docs.xarray.dev/en/latest/api.html#id6 def all(self, dim): @@ -392,6 +543,15 @@ def cumsum(self, dim): def cumprod(self, dim): return px.reduction.cumprod(self, dim) + def diff(self, dim, n=1): + """Compute the n-th discrete difference along the given dimension.""" + slice1 = {dim: slice(1, None)} + slice2 = {dim: slice(None, -1)} + x = self + for _ in range(n): + x = x[slice1] - x[slice2] + return x + class XTensorConstantSignature(tuple): def __eq__(self, other): @@ -470,8 +630,7 @@ def as_xtensor(x, name=None, dims: Sequence[str] | None = None): if isinstance(x, Apply): if len(x.outputs) != 1: raise ValueError( - "It is ambiguous which output of a " - "multi-output Op has to be fetched.", + "It is ambiguous which output of a multi-output Op has to be fetched.", x, ) else: diff --git a/tests/xtensor/test_indexing.py b/tests/xtensor/test_indexing.py new file mode 100644 index 0000000000..8010211b23 --- /dev/null +++ b/tests/xtensor/test_indexing.py @@ -0,0 +1,138 @@ +import numpy as np +import pytest +from xarray import DataArray + +from pytensor.tensor import tensor +from pytensor.xtensor import xtensor +from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function + + +@pytest.mark.parametrize( + "indices", + [ + (0,), + (slice(1, None),), + (slice(None, -1),), + (slice(None, None, -1),), + (0, slice(None), -1, slice(1, None)), + (..., 0, -1), + (0, ..., -1), + (0, -1, ...), + ], +) +@pytest.mark.parametrize("labeled", (False, True), ids=["unlabeled", "labeled"]) +def test_basic_indexing(labeled, indices): + if ... in indices and labeled: + pytest.skip("Ellipsis not supported with labeled indexing") + + dims = ("a", "b", "c", "d") + x = xtensor(dims=dims, shape=(2, 3, 5, 7)) + + if labeled: + shufled_dims = tuple(np.random.permutation(dims)) + indices = dict(zip(shufled_dims, indices, strict=False)) + out = x[indices] + + fn = xr_function([x], out) + x_test_values = np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape( + x.type.shape + ) + x_test = DataArray(x_test_values, dims=x.type.dims) + res = fn(x_test) + expected_res = x_test[indices] + xr_assert_allclose(res, expected_res) + + +def test_single_vector_indexing_on_existing_dim(): + x = xtensor(dims=("a", "b"), shape=(3, 5)) + idx = tensor("idx", dtype=int, shape=(4,)) + xidx = xtensor("idx", dtype=int, shape=(4,), dims=("a",)) + + x_test = xr_arange_like(x) + idx_test = np.array([0, 1, 0, 2], dtype=int) + xidx_test = DataArray(idx_test, dims=("a",)) + + # Three equivalent ways of indexing a->a + y = x[idx] + fn = xr_function([x, idx], y) + res = fn(x_test, idx_test) + expected_res = x_test[idx_test] + xr_assert_allclose(res, expected_res) + + y = x[(("a", idx),)] + fn = xr_function([x, idx], y) + res = fn(x_test, idx_test) + expected_res = x_test[(("a", idx_test),)] + xr_assert_allclose(res, expected_res) + + y = x[xidx] + fn = xr_function([x, xidx], y) + res = fn(x_test, xidx_test) + expected_res = x_test[xidx_test] + xr_assert_allclose(res, expected_res) + + +def test_single_vector_indexing_on_new_dim(): + x = xtensor(dims=("a", "b"), shape=(3, 5)) + idx = tensor("idx", dtype=int, shape=(4,)) + xidx = xtensor("idx", dtype=int, shape=(4,), dims=("a",)) + + x_test = xr_arange_like(x) + idx_test = np.array([0, 1, 0, 2], dtype=int) + xidx_test = DataArray(idx_test, dims=("a",)) + + # Two equvilant ways of indexing a->new_a + y = x[(("new_a", idx),)] + fn = xr_function([x, idx], y) + res = fn(x_test, idx_test) + expected_res = x_test[(("new_a", 
+    expected_res = x_test[(("new_a", idx_test),)]
+    xr_assert_allclose(res, expected_res)
+
+    y = x[xidx.rename(a="new_a")]
+    fn = xr_function([x, xidx], y)
+    res = fn(x_test, xidx_test)
+    expected_res = x_test[xidx_test.rename(a="new_a")]
+    xr_assert_allclose(res, expected_res)
+
+
+def test_single_vector_indexing_interacting_with_existing_dim():
+    x = xtensor(dims=("a", "b"), shape=(3, 5))
+    idx = tensor("idx", dtype=int, shape=(4,))
+    xidx = xtensor("idx", dtype=int, shape=(4,), dims=("a",))
+
+    x_test = xr_arange_like(x)
+    idx_test = np.array([0, 1, 0, 2], dtype=int)
+    xidx_test = DataArray(idx_test, dims=("a",))
+
+    # Two equivalent ways of indexing a->b
+    # By labeling the index on a, as "b", we cause pointwise indexing between the two dimensions.
+    y = x[("b", idx), 1:]
+    fn = xr_function([x, idx], y)
+    res = fn(x_test, idx_test)
+    expected_res = x_test[("b", idx_test), 1:]
+    xr_assert_allclose(res, expected_res)
+
+    y = x[xidx.rename(a="b"), 1:]
+    fn = xr_function([x, xidx], y)
+    res = fn(x_test, xidx_test)
+    expected_res = x_test[xidx_test.rename(a="b"), 1:]
+    xr_assert_allclose(res, expected_res)
+
+
+@pytest.mark.parametrize("n", ["implicit", 1, 2])
+@pytest.mark.parametrize("dim", ["a", "b"])
+def test_diff(dim, n):
+    x = xtensor(dims=("a", "b"), shape=(7, 11))
+    if n == "implicit":
+        out = x.diff(dim)
+    else:
+        out = x.diff(dim, n=n)
+
+    fn = xr_function([x], out)
+    x_test = xr_arange_like(x)
+    res = fn(x_test)
+    if n == "implicit":
+        expected_res = x_test.diff(dim)
+    else:
+        expected_res = x_test.diff(dim, n=n)
+    xr_assert_allclose(res, expected_res)
diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py
index 79cc2738a2..57f182650f 100644
--- a/tests/xtensor/test_shape.py
+++ b/tests/xtensor/test_shape.py
@@ -1,18 +1,21 @@
 # ruff: noqa: E402
+import numpy as np
 import pytest

 pytest.importorskip("xarray")

 from itertools import chain, combinations

-import numpy as np
-from xarray import DataArray
 from xarray import concat as xr_concat

-from pytensor.xtensor.shape import concat, stack
+from pytensor.xtensor.shape import concat, expand_dims, squeeze, stack, transpose
 from pytensor.xtensor.type import xtensor
-from tests.xtensor.util import xr_assert_allclose, xr_function, xr_random_like
+from tests.xtensor.util import (
+    xr_arange_like,
+    xr_assert_allclose,
+    xr_function,
+    xr_random_like,
+)


 def powerset(iterable, min_group_size=0):
@@ -24,9 +27,7 @@ def powerset(iterable, min_group_size=0):
     )


-@pytest.mark.xfail(reason="Not yet implemented")
 def test_transpose():
-    transpose = None
     a, b, c, d, e = "abcde"
     x = xtensor("x", dims=(a, b, c, d, e), shape=(2, 3, 5, 7, 11))
@@ -42,10 +43,7 @@ def test_transpose():
     outs = [transpose(x, *perm) for perm in permutations]

     fn = xr_function([x], outs)
-    x_test = DataArray(
-        np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
-        dims=x.type.dims,
-    )
+    x_test = xr_arange_like(x)
     res = fn(x_test)
     expected_res = [x_test.transpose(*perm) for perm in permutations]
     for outs_i, res_i, expected_res_i in zip(outs, res, expected_res):
@@ -61,10 +59,7 @@ def test_stack():
     ]

     fn = xr_function([x], outs)
-    x_test = DataArray(
-        np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
-        dims=x.type.dims,
-    )
+    x_test = xr_arange_like(x)
     res = fn(x_test)

     expected_res = [
@@ -81,10 +76,7 @@ def test_stack_single_dim():
     assert out.type.dims == ("b", "c", "d")

     fn = xr_function([x], out)
-    x_test = DataArray(
-        np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
-        dims=x.type.dims,
-    )
+    x_test = xr_arange_like(x)
     fn.fn.dprint(print_type=True)
     res = fn(x_test)
     expected_res = x_test.stack(d=["a"])
@@ -96,10 +88,7 @@ def test_multiple_stacks():
     out = stack(x, new_dim1=("a", "b"), new_dim2=("c", "d"))

     fn = xr_function([x], [out])
-    x_test = DataArray(
-        np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
-        dims=x.type.dims,
-    )
+    x_test = xr_arange_like(x)
     res = fn(x_test)
     expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d"))
     xr_assert_allclose(res[0], expected_res)
@@ -163,3 +152,163 @@ def test_concat_scalar():
     res = fn(x1_test, x2_test)
     expected_res = xr_concat([x1_test, x2_test], dim="new_dim")
     xr_assert_allclose(res, expected_res)
+
+
+def test_xtensor_variable_transpose():
+    """Test the transpose() method of XTensorVariable."""
+    x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
+
+    # Test basic transpose
+    out = x.transpose()
+    fn = xr_function([x], out)
+    x_test = xr_arange_like(x)
+    xr_assert_allclose(fn(x_test), x_test.transpose())
+
+    # Test transpose with specific dimensions
+    out = x.transpose("c", "a", "b")
+    fn = xr_function([x], out)
+    xr_assert_allclose(fn(x_test), x_test.transpose("c", "a", "b"))
+
+    # Test transpose with ellipsis
+    out = x.transpose("c", ...)
+    fn = xr_function([x], out)
+    xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
+
+    # Test error cases
+    with pytest.raises(
+        ValueError,
+        match="Invalid dimensions: {'d'}. Available dimensions: \\('a', 'b', 'c'\\)",
+    ):
+        x.transpose("d")
+
+    with pytest.raises(ValueError, match="an index can only have a single ellipsis"):
+        x.transpose("a", ..., "b", ...)
+
+    # Test missing_dims parameter
+    # Test ignore
+    out = x.transpose("c", ..., "d", missing_dims="ignore")
+    fn = xr_function([x], out)
+    xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
+
+    # Test warn
+    with pytest.warns(UserWarning, match="Dimensions {'d'} do not exist"):
+        out = x.transpose("c", ..., "d", missing_dims="warn")
+    fn = xr_function([x], out)
+    xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
+
+
+def test_xtensor_variable_T():
+    """Test the T property of XTensorVariable."""
+    # Test T property with 3D tensor
+    x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
+    out = x.T
+
+    fn = xr_function([x], out)
+    x_test = xr_arange_like(x)
+    xr_assert_allclose(fn(x_test), x_test.transpose())
+
+    # Test T property with 2D tensor
+    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
+    out = x.T
+
+    fn = xr_function([x], out)
+    x_test = xr_arange_like(x)
+    xr_assert_allclose(fn(x_test), x_test.transpose())
+
+
+def test_expand_dims():
+    # Test 1D tensor expansion
+    x = xtensor("x", dims=("city",), shape=(3,))
+    y = expand_dims(x, "country")
+    assert y.type.dims == ("city", "country")
+    assert y.type.shape == (3, 1)
+
+    # Test 2D tensor expansion
+    x2d = xtensor("x2d", dims=("row", "col"), shape=(2, 3))
+    y2d = expand_dims(x2d, "batch")
+    assert y2d.type.dims == ("row", "col", "batch")
+    assert y2d.type.shape == (2, 3, 1)
+
+    # Test expansion with different dimension name
+    z = expand_dims(x, "time")
+    assert z.type.dims == ("city", "time")
+    assert z.type.shape == (3, 1)
+
+    # Test that expanding with an existing dimension raises an error
+    with pytest.raises(ValueError):
+        expand_dims(y, "city")
+
+    # Test that expanding with None dimension works
+    z = expand_dims(x, None)
+    assert z.type.dims == ("city", None)
+    assert z.type.shape == (3, 1)
+
+
+def test_squeeze():
+    # Test 1D tensor with no squeezable dimensions
dims=("city",), shape=(3,)) + with pytest.raises(ValueError, match="No dimensions of size 1 to remove"): + squeeze(x) + + # Test 2D tensor with one squeezable dimension + x2d = xtensor("x2d", dims=("row", "col"), shape=(2, 1)) + y2d = squeeze(x2d) + assert y2d.type.dims == ("row",) + assert y2d.type.shape == (2,) + + # Test 3D tensor with multiple squeezable dimensions + x3d = xtensor("x3d", dims=("batch", "row", "col"), shape=(1, 2, 1)) + y3d = squeeze(x3d) + assert y3d.type.dims == ("row",) + assert y3d.type.shape == (2,) + + # Test squeezing specific dimension + x3d = xtensor("x3d", dims=("batch", "row", "col"), shape=(1, 2, 1)) + y3d = squeeze(x3d, dim="batch") + assert y3d.type.dims == ("row", "col") + assert y3d.type.shape == (2, 1) + + # Test squeezing non-existent dimension + with pytest.raises(ValueError, match="Dimension time not found"): + squeeze(x3d, dim="time") + + # Test squeezing dimension with size > 1 + x3d = xtensor("x3d", dims=("batch", "row", "col"), shape=(2, 2, 1)) + with pytest.raises(ValueError, match="Dimension batch has size 2, not 1"): + squeeze(x3d, dim="batch") + + # Test functional interface + fn = xr_function([x2d], y2d) + x_test = xr_arange_like(x2d) + res = fn(x_test) + expected_res = x_test.squeeze() + xr_assert_allclose(res, expected_res) + + # Test squeezing a tensor with multiple squeezable dimensions + x_multi = xtensor("x_multi", dims=("batch", "row", "col"), shape=(1, 2, 1)) + y_multi = squeeze(x_multi) + assert y_multi.type.dims == ("row",) + assert y_multi.type.shape == (2,) + fn_multi = xr_function([x_multi], y_multi) + x_multi_test = xr_arange_like(x_multi) + res_multi = fn_multi(x_multi_test) + expected_res_multi = x_multi_test.squeeze() + xr_assert_allclose(res_multi, expected_res_multi) + + +def test_lower_squeeze(): + from pytensor.xtensor.rewriting.shape import lower_squeeze + from pytensor.xtensor.shape import squeeze + from pytensor.xtensor.type import xtensor + + # Create a tensor with a squeezable dimension + x = xtensor("x", dims=("row", "col"), shape=(2, 1)) + y = squeeze(x) + + class DummyFGraph: + pass + + node = type("Node", (), {"inputs": [x], "op": y.owner.op, "outputs": [y]})() + [out] = lower_squeeze.transform(DummyFGraph(), node) + assert out.type.dims == ("row",) + assert out.type.shape == (2,)