@@ -1432,6 +1432,8 @@ def test_usm_array(self):
14321432
14331433
14341434class TestTrimZeros :
1435+ ALL_TRIMS = ["F" , "B" , "fb" ]
1436+
14351437 @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_none = True ))
14361438 def test_basic (self , dtype ):
14371439 a = numpy .array ([0 , 0 , 1 , 0 , 2 , 3 , 4 , 0 ], dtype = dtype )
@@ -1443,7 +1445,7 @@ def test_basic(self, dtype):
14431445
14441446 @testing .with_requires ("numpy>=2.2" )
14451447 @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_none = True ))
1446- @pytest .mark .parametrize ("trim" , [ "F" , "B" , "fb" ] )
1448+ @pytest .mark .parametrize ("trim" , ALL_TRIMS )
14471449 @pytest .mark .parametrize ("ndim" , [0 , 1 , 2 , 3 ])
14481450 def test_basic_nd (self , dtype , trim , ndim ):
14491451 a = numpy .ones ((2 ,) * ndim , dtype = dtype )
@@ -1477,7 +1479,7 @@ def test_all_zero(self, dtype, trim):
14771479
14781480 @testing .with_requires ("numpy>=2.2" )
14791481 @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_none = True ))
1480- @pytest .mark .parametrize ("trim" , [ "F" , "B" , "fb" ] )
1482+ @pytest .mark .parametrize ("trim" , ALL_TRIMS )
14811483 @pytest .mark .parametrize ("ndim" , [0 , 1 , 2 , 3 ])
14821484 def test_all_zero_nd (self , dtype , trim , ndim ):
14831485 a = numpy .zeros ((3 ,) * ndim , dtype = dtype )
@@ -1496,6 +1498,51 @@ def test_size_zero(self):
14961498 expected = numpy .trim_zeros (a )
14971499 assert_array_equal (result , expected )
14981500
1501+ @testing .with_requires ("numpy>=2.4" )
1502+ @pytest .mark .parametrize (
1503+ "shape, axis" ,
1504+ [
1505+ [(5 ,), None ],
1506+ [(5 ,), ()],
1507+ [(5 ,), 0 ],
1508+ [(5 , 6 ), None ],
1509+ [(5 , 6 ), ()],
1510+ [(5 , 6 ), 0 ],
1511+ [(5 , 6 ), (- 1 ,)],
1512+ [(5 , 6 , 7 ), None ],
1513+ [(5 , 6 , 7 ), ()],
1514+ [(5 , 6 , 7 ), 1 ],
1515+ [(5 , 6 , 7 ), (0 , 2 )],
1516+ [(5 , 6 , 7 , 8 ), None ],
1517+ [(5 , 6 , 7 , 8 ), ()],
1518+ [(5 , 6 , 7 , 8 ), - 2 ],
1519+ [(5 , 6 , 7 , 8 ), (0 , 1 , 3 )],
1520+ ],
1521+ )
1522+ @pytest .mark .parametrize ("trim" , ALL_TRIMS )
1523+ def test_multiple_axes (self , shape , axis , trim ):
1524+ # standardize axis to a tuple
1525+ if axis is None :
1526+ axis = tuple (range (len (shape )))
1527+ elif isinstance (axis , int ):
1528+ axis = (len (shape ) + axis if axis < 0 else axis ,)
1529+ else :
1530+ axis = tuple (len (shape ) + ax if ax < 0 else ax for ax in axis )
1531+
1532+ # populate a random interior slice with nonzero entries
1533+ rng = numpy .random .default_rng (4321 )
1534+ a = numpy .zeros (shape )
1535+ start = rng .integers (low = 0 , high = numpy .array (shape ) - 1 )
1536+ end = rng .integers (low = start + 1 , high = shape )
1537+ shape = tuple (end - start )
1538+ data = 1 + rng .random (shape )
1539+ a [tuple (slice (i , j ) for i , j in zip (start , end ))] = data
1540+ ia = dpnp .array (a )
1541+
1542+ result = dpnp .trim_zeros (ia , axis = axis , trim = trim )
1543+ expected = numpy .trim_zeros (a , axis = axis , trim = trim )
1544+ assert_array_equal (result , expected )
1545+
14991546 @pytest .mark .parametrize (
15001547 "a" , [numpy .array ([0 , 2 ** 62 , 0 ]), numpy .array ([0 , 2 ** 63 , 0 ])]
15011548 )
0 commit comments