Skip to content

Commit 52f9c02

Browse files
committed
Eager optimization for no-op flatten
1 parent 3c43234 commit 52f9c02

File tree

2 files changed

+8
-17
lines changed

2 files changed

+8
-17
lines changed

pytensor/tensor/basic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3081,6 +3081,10 @@ def flatten(x, ndim=1):
30813081
else:
30823082
dims = (-1,)
30833083

3084+
if len(dims) == _x.ndim:
3085+
# Nothing to ravel
3086+
return _x
3087+
30843088
x_reshaped = _x.reshape(dims)
30853089
shape_kept_dims = _x.type.shape[: ndim - 1]
30863090
bcast_new_dim = builtins.all(s == 1 for s in _x.type.shape[ndim - 1 :])

tests/tensor/test_basic.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3867,35 +3867,22 @@ class TestInferShape(utt.InferShapeTester):
38673867
def test_Flatten(self):
38683868
atens3 = tensor3()
38693869
atens3_val = random(4, 5, 3)
3870-
for ndim in (3, 2, 1):
3870+
for ndim in (2, 1):
38713871
self._compile_and_check(
38723872
[atens3],
38733873
[flatten(atens3, ndim)],
38743874
[atens3_val],
38753875
Reshape,
3876-
excluding=["local_useless_reshape"],
38773876
)
38783877

38793878
amat = matrix()
38803879
amat_val = random(4, 5)
3881-
for ndim in (2, 1):
3882-
self._compile_and_check(
3883-
[amat],
3884-
[flatten(amat, ndim)],
3885-
[amat_val],
3886-
Reshape,
3887-
excluding=["local_useless_reshape"],
3888-
)
3889-
3890-
avec = vector()
3891-
avec_val = random(4)
38923880
ndim = 1
38933881
self._compile_and_check(
3894-
[avec],
3895-
[flatten(avec, ndim)],
3896-
[avec_val],
3882+
[amat],
3883+
[flatten(amat, ndim)],
3884+
[amat_val],
38973885
Reshape,
3898-
excluding=["local_useless_reshape"],
38993886
)
39003887

39013888
def test_Eye(self):

0 commit comments

Comments
 (0)