@@ -49,6 +49,8 @@ def __init__(
49
49
self .shape = tuple (shape )
50
50
self .ndim = len (self .dims )
51
51
self .name = name
52
+ self .numpy_dtype = np .dtype (self .dtype )
53
+ self .filter_checks_isfinite = False
52
54
53
55
def clone (
54
56
self ,
@@ -66,8 +68,9 @@ def clone(
66
68
return type (self )(dtype = dtype , shape = shape , dims = dims , ** kwargs )
67
69
68
70
def filter (self , value , strict = False , allow_downcast = None ):
69
- # TODO implement this
70
- return value
71
+ return TensorType .filter (
72
+ self , value , strict = strict , allow_downcast = allow_downcast
73
+ )
71
74
72
75
def convert_variable (self , var ):
73
76
# TODO: Implement this
@@ -530,16 +533,19 @@ def as_xtensor(x, name=None, dims: Sequence[str] | None = None):
530
533
if isinstance (x .type , XTensorType ):
531
534
return x
532
535
if isinstance (x .type , TensorType ):
533
- if x .type .ndim > 0 and dims is None :
534
- raise TypeError (
535
- "non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
536
- )
537
- return px .basic .xtensor_from_tensor (x , dims )
536
+ if dims is None :
537
+ if x .type .ndim == 0 :
538
+ dims = ()
539
+ else :
540
+ raise TypeError (
541
+ "non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
542
+ )
543
+ return px .basic .xtensor_from_tensor (x , dims = dims , name = name )
538
544
else :
539
545
raise TypeError (
540
546
"Variable with type {x.type} cannot be converted to XTensorVariable."
541
547
)
542
548
try :
543
- return xtensor_constant (x , name = name , dims = dims )
549
+ return xtensor_constant (x , dims = dims , name = name )
544
550
except TypeError as err :
545
551
raise TypeError (f"Cannot convert { x } to XTensorType { type (x )} " ) from err
0 commit comments