Skip to content

Commit fb31b6a

Browse files
authored
Merge pull request data-apis#356 from cakedev0/fix/torch-argsort-stability
FIX: Wrap torch.argsort to set stable=True by default
2 parents a65cbc4 + 1fafdda commit fb31b6a

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,20 @@ def sort(
241241
) -> Array:
242242
return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values
243243

244+
245+
# Wrap torch.argsort to set stable=True by default
246+
def argsort(
247+
x: Array,
248+
/,
249+
*,
250+
axis: int = -1,
251+
descending: bool = False,
252+
stable: bool = True,
253+
**kwargs: object,
254+
) -> Array:
255+
return torch.argsort(x, dim=axis, descending=descending, stable=stable, **kwargs)
256+
257+
244258
def _normalize_axes(axis, ndim):
245259
axes = []
246260
if ndim == 0 and axis:
@@ -837,9 +851,9 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array
837851
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
838852
'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
839853
'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',
840-
'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', 'prod', 'sum',
841-
'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
842-
'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
854+
'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort',
855+
'argsort', 'prod', 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat',
856+
'squeeze', 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
843857
'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
844858
'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays',
845859
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',

array_api_compat/torch/fft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Sequence
44
from typing import Literal
55

6-
import torch
6+
import torch # noqa: F401
77
import torch.fft
88

99
from ._typing import Array

0 commit comments

Comments
 (0)