File tree Expand file tree Collapse file tree 2 files changed +8
-17
lines changed Expand file tree Collapse file tree 2 files changed +8
-17
lines changed Original file line number Diff line number Diff line change @@ -3081,6 +3081,10 @@ def flatten(x, ndim=1):
3081
3081
else :
3082
3082
dims = (- 1 ,)
3083
3083
3084
+ if len (dims ) == _x .ndim :
3085
+ # Nothing to ravel
3086
+ return _x
3087
+
3084
3088
x_reshaped = _x .reshape (dims )
3085
3089
shape_kept_dims = _x .type .shape [: ndim - 1 ]
3086
3090
bcast_new_dim = builtins .all (s == 1 for s in _x .type .shape [ndim - 1 :])
Original file line number Diff line number Diff line change @@ -3867,35 +3867,22 @@ class TestInferShape(utt.InferShapeTester):
3867
3867
def test_Flatten (self ):
3868
3868
atens3 = tensor3 ()
3869
3869
atens3_val = random (4 , 5 , 3 )
3870
- for ndim in (3 , 2 , 1 ):
3870
+ for ndim in (2 , 1 ):
3871
3871
self ._compile_and_check (
3872
3872
[atens3 ],
3873
3873
[flatten (atens3 , ndim )],
3874
3874
[atens3_val ],
3875
3875
Reshape ,
3876
- excluding = ["local_useless_reshape" ],
3877
3876
)
3878
3877
3879
3878
amat = matrix ()
3880
3879
amat_val = random (4 , 5 )
3881
- for ndim in (2 , 1 ):
3882
- self ._compile_and_check (
3883
- [amat ],
3884
- [flatten (amat , ndim )],
3885
- [amat_val ],
3886
- Reshape ,
3887
- excluding = ["local_useless_reshape" ],
3888
- )
3889
-
3890
- avec = vector ()
3891
- avec_val = random (4 )
3892
3880
ndim = 1
3893
3881
self ._compile_and_check (
3894
- [avec ],
3895
- [flatten (avec , ndim )],
3896
- [avec_val ],
3882
+ [amat ],
3883
+ [flatten (amat , ndim )],
3884
+ [amat_val ],
3897
3885
Reshape ,
3898
- excluding = ["local_useless_reshape" ],
3899
3886
)
3900
3887
3901
3888
def test_Eye (self ):
You can’t perform that action at this time.
0 commit comments