Skip to content

Commit 16978e6

Browse files
authored
Merge pull request #305 from ev-br/torch_repeat
BUG: add torch.repeat
2 parents 2f8e63a + 00e7cce commit 16978e6

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

array_api_compat/torch/_aliases.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,11 @@ def count_nonzero(
555555
return result
556556

557557

558+
# "repeat" is torch.repeat_interleave; also the dim argument
559+
def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array:
560+
return torch.repeat_interleave(x, repeats, axis)
561+
562+
558563
def where(
559564
condition: Array,
560565
x1: Array | bool | int | float | complex,
@@ -835,6 +840,6 @@ def sign(x: Array, /) -> Array:
835840
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
836841
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
837842
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
838-
'take', 'take_along_axis', 'sign', 'finfo', 'iinfo']
843+
'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat']
839844

840845
_all_ignore = ['torch', 'get_xp']

torch-xfails.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,8 @@ array_api_tests/test_data_type_functions.py::test_finfo_dtype
124124
array_api_tests/test_data_type_functions.py::test_iinfo_dtype
125125

126126
# 2023.12 support
127-
array_api_tests/test_has_names.py::test_has_names[manipulation-repeat]
127+
# https://github.com/pytorch/pytorch/issues/151311: torch.repeat_interleave rejects short integers
128128
array_api_tests/test_manipulation_functions.py::test_repeat
129-
array_api_tests/test_signatures.py::test_func_signature[repeat]
130129
# Argument 'device' missing from signature
131130
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
132131
# Argument 'max_version' missing from signature

0 commit comments

Comments
 (0)