Skip to content

Commit 7633a48

Browse files
committed
Implement proper type.filter
1 parent bde593a commit 7633a48

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

pytensor/xtensor/type.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ def __init__(
4949
self.shape = tuple(shape)
5050
self.ndim = len(self.dims)
5151
self.name = name
52+
self.numpy_dtype = np.dtype(self.dtype)
53+
self.filter_checks_isfinite = False
5254

5355
def clone(
5456
self,
@@ -66,8 +68,9 @@ def clone(
6668
return type(self)(dtype=dtype, shape=shape, dims=dims, **kwargs)
6769

6870
def filter(self, value, strict=False, allow_downcast=None):
69-
# TODO implement this
70-
return value
71+
return TensorType.filter(
72+
self, value, strict=strict, allow_downcast=allow_downcast
73+
)
7174

7275
def convert_variable(self, var):
7376
# TODO: Implement this
@@ -530,16 +533,19 @@ def as_xtensor(x, name=None, dims: Sequence[str] | None = None):
530533
if isinstance(x.type, XTensorType):
531534
return x
532535
if isinstance(x.type, TensorType):
533-
if x.type.ndim > 0 and dims is None:
534-
raise TypeError(
535-
"non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
536-
)
537-
return px.basic.xtensor_from_tensor(x, dims)
536+
if dims is None:
537+
if x.type.ndim == 0:
538+
dims = ()
539+
else:
540+
raise TypeError(
541+
"non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
542+
)
543+
return px.basic.xtensor_from_tensor(x, dims=dims, name=name)
538544
else:
539545
raise TypeError(
540546
"Variable with type {x.type} cannot be converted to XTensorVariable."
541547
)
542548
try:
543-
return xtensor_constant(x, name=name, dims=dims)
549+
return xtensor_constant(x, dims=dims, name=name)
544550
except TypeError as err:
545551
raise TypeError(f"Cannot convert {x} to XTensorType {type(x)}") from err

0 commit comments

Comments
 (0)