Skip to content
This repository was archived by the owner on Feb 17, 2021. It is now read-only.

Commit 47a5d0c

Browse files
committed
Add isfinite, isinf, isscalar, diagonal, allclose, and improve isnan support in NumPy stubs
also add support for scalar + array fixes #209
1 parent 44dd95a commit 47a5d0c

File tree

2 files changed

+137
-4
lines changed

2 files changed

+137
-4
lines changed

numpy-stubs/__init__.pyi

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,8 @@ class ndarray(Generic[_DType]):
390390
def __radd__(self, value: ndarray[_DType]) -> ndarray[_DType]: ...
391391
@overload
392392
def __radd__(self, value: _DType) -> ndarray[_DType]: ...
393+
@overload
394+
def __radd__(self, value: float) -> ndarray[_DType]: ...
393395
def __rand__(self, value: object) -> ndarray[_DType]: ...
394396
def __rdivmod__(self, value: object) -> Tuple[ndarray[_DType], ndarray[_DType]]: ...
395397
def __rfloordiv__(self, value: object) -> ndarray[_DType]: ...
@@ -774,10 +776,6 @@ def interp(
774776
) -> ndarray: ...
775777
def isin(element: Sequence[_DType], test_element: _DType) -> ndarray[_DType]: ...
776778
@overload
777-
def isnan(x: float64) -> bool: ...
778-
@overload
779-
def isnan(x: ndarray[_DType]) -> ndarray[bool_]: ...
780-
@overload
781779
def ix_(x: ndarray[_DType]) -> ndarray[_DType]: ...
782780
@overload
783781
def ix_(x1: ndarray[_DType], x2: ndarray[_DType]) -> Tuple[ndarray[_DType], ndarray[_DType]]: ...
@@ -993,6 +991,92 @@ def set_printoptions(
993991
*,
994992
legacy: Any = ...,
995993
) -> None: ...
994+
def isscalar(element: Any) -> bool: ...
995+
def diagonal(a: _ArrayLike, offset: int = ..., axis1: int = ..., axis2: int = ...) -> ndarray: ...
996+
def allclose(
997+
a: Union[_ArrayLike, _FloatLike],
998+
b: Union[_ArrayLike, _FloatLike],
999+
rtol: float = ...,
1000+
atol: float = ...,
1001+
equal_nan: bool = ...,
1002+
) -> bool: ...
1003+
1004+
#
1005+
# ufunc
1006+
#
1007+
1008+
# Backported from latest NumPy
1009+
class ufunc:
1010+
@property
1011+
def __name__(self) -> str: ...
1012+
def __call__(
1013+
self,
1014+
*args: Union[_FloatLike, _ArrayLike],
1015+
out: Optional[Union[ndarray, Tuple[ndarray, ...]]] = ...,
1016+
where: Optional[ndarray] = ...,
1017+
# The list should be a list of tuples of ints, but since we
1018+
# don't know the signature it would need to be
1019+
# Tuple[int, ...]. But, since List is invariant something like
1020+
# e.g. List[Tuple[int, int]] isn't a subtype of
1021+
# List[Tuple[int, ...]], so we can't type precisely here.
1022+
axes: List[Any] = ...,
1023+
axis: int = ...,
1024+
keepdims: bool = ...,
1025+
casting: Any = ...,
1026+
order: Any = ...,
1027+
dtype: Any = ...,
1028+
subok: bool = ...,
1029+
signature: Union[str, Tuple[str]] = ...,
1030+
# In reality this should be a length of list 3 containing an
1031+
# int, an int, and a callable, but there's no way to express
1032+
# that.
1033+
extobj: List[Union[int, Callable]] = ...,
1034+
) -> Any: ...
1035+
@property
1036+
def nin(self) -> int: ...
1037+
@property
1038+
def nout(self) -> int: ...
1039+
@property
1040+
def nargs(self) -> int: ...
1041+
@property
1042+
def ntypes(self) -> int: ...
1043+
@property
1044+
def types(self) -> List[str]: ...
1045+
# Broad return type because it has to encompass things like
1046+
#
1047+
# >>> np.logical_and.identity is True
1048+
# True
1049+
# >>> np.add.identity is 0
1050+
# True
1051+
# >>> np.sin.identity is None
1052+
# True
1053+
#
1054+
# and any user-defined ufuncs.
1055+
@property
1056+
def identity(self) -> Any: ...
1057+
# This is None for ufuncs and a string for gufuncs.
1058+
@property
1059+
def signature(self) -> Optional[str]: ...
1060+
# The next four methods will always exist, but they will just
1061+
# raise a ValueError ufuncs with that don't accept two input
1062+
# arguments and return one output argument. Because of that we
1063+
# can't type them very precisely.
1064+
@property
1065+
def reduce(self) -> Any: ...
1066+
@property
1067+
def accumulate(self) -> Any: ...
1068+
@property
1069+
def reduceat(self) -> Any: ...
1070+
@property
1071+
def outer(self) -> Any: ...
1072+
# Similarly at won't be defined for ufuncs that return multiple
1073+
# outputs, so we can't type it very precisely.
1074+
@property
1075+
def at(self) -> Any: ...
1076+
1077+
isfinite: ufunc
1078+
isinf: ufunc
1079+
isnan: ufunc
9961080

9971081
#
9981082
# Specific values

tests/numpy_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,3 +259,52 @@ def test_interp() -> None:
259259
def test_genfromtxt() -> None:
260260
result = np.genfromtxt(["0.1, 0.2"], dtype=np.float64, delimiter=",")
261261
assert list(result) == [0.1, 0.2]
262+
263+
264+
def test_isfinite_isinf_isnan() -> None:
265+
import math
266+
267+
assert np.isfinite(0.0)
268+
assert not np.isfinite(np.inf)
269+
assert np.isinf(np.inf)
270+
assert not np.isfinite(math.inf)
271+
assert not np.isfinite(np.nan)
272+
assert np.isnan(np.nan)
273+
assert np.all(np.isfinite([0.0, -np.inf]) == [True, False])
274+
assert np.all(np.isfinite(np.array([0.0, np.nan])) == np.array([True, False]))
275+
assert np.all(
276+
np.isfinite(np.array([np.inf, np.nan], dtype=np.float32)) == np.array([False, False])
277+
)
278+
assert np.all(np.isnan([0.0, -np.inf]) == [False, False])
279+
assert np.all(np.isinf([0.0, -np.inf]) == [False, True])
280+
281+
282+
def test_diagonal() -> None:
283+
assert np.all(np.diagonal([[1]]) == np.array([1]))
284+
x = np.arange(12).reshape(3, 4)
285+
assert np.all(np.diagonal(x) == np.array([0.0, 5.0, 10.0]))
286+
287+
288+
def test_allclose() -> None:
289+
assert np.allclose([1.0, 2.0], [1.0 + 1e-9, 2.0 + 1e-9])
290+
assert np.allclose(np.array([1.0, 2.0]), np.array([1.0 + 1e-9, 2.0 + 1e-9]))
291+
assert np.allclose(np.array([1.0, 1.0]), 1.0 + 1e-9)
292+
assert np.allclose(1.0 + 1e-9, np.array([1.0, 1.0]))
293+
294+
295+
def test_isscalar() -> None:
296+
assert np.isscalar(1.0)
297+
assert not np.isscalar([1.0])
298+
assert not np.isscalar(np.array([1.0]))
299+
assert not np.isscalar(np.array([]))
300+
assert np.isscalar(np.array([1.0], dtype=np.float32)[0])
301+
302+
303+
def test_newaxis() -> None:
304+
x = np.array([1.0, 2.0])
305+
assert x[np.newaxis, :].shape == (1, 2)
306+
307+
308+
def test_sum_scalar_before() -> None:
309+
x = 273.15 + np.array([-0.1e2, -0.77e1])
310+
assert isinstance(x, np.ndarray)

0 commit comments

Comments
 (0)