diff --git a/numpyro/_typing.py b/numpyro/_typing.py index 10cc1c9e8..684f05a3d 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -4,7 +4,7 @@ from collections import OrderedDict from collections.abc import Callable -from typing import Any, Optional, Protocol, Union, runtime_checkable +from typing import Any, Optional, Protocol, TypeVar, Union, runtime_checkable import weakref try: @@ -12,6 +12,7 @@ except ImportError: from typing_extensions import ParamSpec, TypeAlias + import numpy as np import jax @@ -36,6 +37,9 @@ """A generic type for a pytree, i.e. a nested structure of lists, tuples, dicts, and arrays.""" +NumLikeT = TypeVar("NumLikeT", bound=NumLike) + + @runtime_checkable class ConstraintT(Protocol): """A protocol for typing constraints.""" @@ -45,10 +49,10 @@ def is_discrete(self) -> bool: ... @property def event_dim(self) -> int: ... - def __call__(self, x: ArrayLike) -> ArrayLike: ... + def __call__(self, x: NumLike) -> ArrayLike: ... def __repr__(self) -> str: ... - def check(self, value: ArrayLike) -> ArrayLike: ... - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: ... + def check(self, value: NumLike) -> ArrayLike: ... + def feasible_like(self, prototype: NumLike) -> NumLike: ... @runtime_checkable diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 417b5f260..f7c1e846a 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -64,7 +64,7 @@ ] import math -from typing import Optional +from typing import Generic, Optional import numpy as np @@ -73,10 +73,10 @@ from jax.tree_util import register_pytree_node from jax.typing import ArrayLike -from numpyro._typing import ConstraintT, NonScalarArray, NumLike +from numpyro._typing import ConstraintT, NonScalarArray, NumLike, NumLikeT -class Constraint(object): +class Constraint(Generic[NumLikeT]): """ Abstract base class for constraints. @@ -84,27 +84,27 @@ class Constraint(object): e.g. within which a variable can be optimized. """ - is_discrete = False - event_dim = 0 + is_discrete: bool = False + event_dim: int = 0 def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLikeT) -> ArrayLike: raise NotImplementedError def __repr__(self) -> str: return self.__class__.__name__[1:] + "()" - def check(self, value: ArrayLike) -> ArrayLike: + def check(self, value: NumLikeT) -> ArrayLike: """ Returns a byte tensor of `sample_shape + batch_shape` indicating whether each event in value satisfies this constraint. """ return self(value) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLikeT) -> NumLikeT: """ Get a feasible value which has the same shape as dtype as `prototype`. """ @@ -122,12 +122,12 @@ def tree_unflatten(cls, aux_data, params): return self -class ParameterFreeConstraint(Constraint): +class ParameterFreeConstraint(Constraint[NumLikeT]): def tree_flatten(self): return (), ((), dict()) -class _SingletonConstraint(ParameterFreeConstraint): +class _SingletonConstraint(ParameterFreeConstraint[NumLikeT]): """ A constraint type which has only one canonical instance, like constraints.real, and unlike constraints.interval. @@ -140,29 +140,28 @@ def __new__(cls): return cls._instance -class _Boolean(_SingletonConstraint): +class _Boolean(_SingletonConstraint[NumLike]): is_discrete = True - def __call__(self, x: ArrayLike) -> ArrayLike: - return (x == 0) | (x == 1) + def __call__(self, x: NumLike) -> ArrayLike: + xp = jax.numpy if isinstance(x, jax.Array) else np + return xp.equal(x, 0) | xp.equal(x, 1) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.zeros_like(prototype) -class _CorrCholesky(_SingletonConstraint): +class _CorrCholesky(_SingletonConstraint[NonScalarArray]): event_dim = 2 - def __call__(self, x: NonScalarArray) -> NonScalarArray: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy - tril = jnp.tril(x) - lower_triangular = jnp.all( - jnp.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1 - ) - positive_diagonal = jnp.all(jnp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1) - x_norm = jnp.linalg.norm(x, axis=-1) - tol = jnp.finfo(x.dtype).eps * x.shape[-1] * 10 - unit_norm_row = jnp.all(jnp.abs(x_norm - 1) <= tol, axis=-1) + def __call__(self, x: NonScalarArray) -> ArrayLike: + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + tril = xp.tril(x) + lower_triangular = xp.all(xp.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1) + positive_diagonal = xp.all(xp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1) + x_norm = xp.linalg.norm(x, axis=-1) + tol = xp.finfo(x.dtype).eps * x.shape[-1] * 10 + unit_norm_row = xp.all(xp.abs(x_norm - 1) <= tol, axis=-1) return lower_triangular & positive_diagonal & unit_norm_row def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: @@ -171,18 +170,18 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: ) -class _CorrMatrix(_SingletonConstraint): +class _CorrMatrix(_SingletonConstraint[NonScalarArray]): event_dim = 2 - def __call__(self, x: NonScalarArray) -> NonScalarArray: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + def __call__(self, x: NonScalarArray) -> ArrayLike: + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric - symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) + symmetric = xp.all(xp.isclose(x, xp.swapaxes(x, -2, -1)), axis=(-2, -1)) # check for the smallest eigenvalue is positive - positive = jnp.linalg.eigvalsh(x)[..., 0] > 0 + positive = xp.linalg.eigvalsh(x)[..., 0] > 0 # check for diagonal equal to 1 - unit_variance = jnp.all( - jnp.abs(jnp.diagonal(x, axis1=-2, axis2=-1) - 1) < 1e-6, axis=-1 + unit_variance = xp.all( + xp.abs(xp.diagonal(x, axis1=-2, axis2=-1) - 1) < 1e-6, axis=-1 ) return symmetric & positive & unit_variance @@ -192,7 +191,7 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: ) -class _Dependent(Constraint): +class _Dependent(Constraint[NumLikeT]): """ Placeholder for variables whose support depends on other variables. These variables obey no simple coordinate-wise constraints. @@ -219,14 +218,14 @@ def is_discrete(self): return self._is_discrete @property - def event_dim(self) -> int: + def event_dim(self) -> int: # type: ignore[override] if self._event_dim is NotImplemented: raise NotImplementedError(".event_dim cannot be determined statically") return self._event_dim def __call__( self, - x: Optional[ArrayLike] = None, + x: Optional[NumLikeT] = None, *, is_discrete: bool = NotImplemented, event_dim: int = NotImplemented, @@ -242,21 +241,22 @@ def __call__( event_dim = self._event_dim return _Dependent(is_discrete=is_discrete, event_dim=event_dim) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: + if not isinstance(other, _Dependent): + return False return ( - type(self) is type(other) - and self._is_discrete == other._is_discrete + self._is_discrete == other._is_discrete and self._event_dim == other._event_dim ) def tree_flatten(self): return (), ( (), - dict(_is_discrete=self._is_discrete, _event_dim=self._event_dim), + dict(_is_discrete=self._is_discrete, _event_dim=self.event_dim), ) -class dependent_property(property, _Dependent): +class dependent_property(property, _Dependent[NumLikeT]): # XXX: this should not need to be pytree-able since it simply wraps a method # and thus is automatically present once the method's object is created def __init__( @@ -264,9 +264,9 @@ def __init__( ): super().__init__(fn) self._is_discrete = is_discrete - self._event_dim = event_dim + self.event_dim = event_dim - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLikeT) -> ArrayLike: if not callable(x): return super().__call__(x) @@ -275,7 +275,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: # def support(self): # ... return dependent_property( - x, is_discrete=self._is_discrete, event_dim=self._event_dim + x, is_discrete=self._is_discrete, event_dim=self.event_dim ) @@ -283,12 +283,13 @@ def is_dependent(constraint): return isinstance(constraint, _Dependent) -class _GreaterThan(Constraint): +class _GreaterThan(Constraint[NumLike]): def __init__(self, lower_bound: NumLike) -> None: self.lower_bound = lower_bound def __call__(self, x: NumLike) -> ArrayLike: - return x > self.lower_bound + xp = jax.numpy if isinstance(x, jax.Array) else np + return xp.greater(x, self.lower_bound) def __repr__(self) -> str: fmt_string = self.__class__.__name__[1:] @@ -301,40 +302,41 @@ def feasible_like(self, prototype: NumLike) -> NumLike: def tree_flatten(self): return (self.lower_bound,), (("lower_bound",), dict()) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _GreaterThan): return False - return jnp.array_equal(self.lower_bound, other.lower_bound) + return jnp.array_equal(self.lower_bound, other.lower_bound) # type: ignore[return-value] class _GreaterThanEq(_GreaterThan): def __call__(self, x: NumLike) -> ArrayLike: - return x >= self.lower_bound + xp = jax.numpy if isinstance(x, jax.Array) else np + return xp.greater_equal(x, self.lower_bound) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _GreaterThanEq): return False - return jnp.array_equal(self.lower_bound, other.lower_bound) + return jnp.array_equal(self.lower_bound, other.lower_bound) # type: ignore[return-value] -class _Positive(_SingletonConstraint, _GreaterThan): +class _Positive(_SingletonConstraint[NumLike], _GreaterThan): def __init__(self) -> None: super().__init__(0.0) -class _Nonnegative(_SingletonConstraint, _GreaterThanEq): +class _Nonnegative(_SingletonConstraint[NumLike], _GreaterThanEq): def __init__(self) -> None: super().__init__(0.0) -class _IndependentConstraint(Constraint): +class _IndependentConstraint(Constraint[NumLikeT]): """ Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many dims in :meth:`check`, so that an event is valid only if all its independent entries are valid. """ - def __init__(self, base_constraint, reinterpreted_batch_ndims): + def __init__(self, base_constraint: ConstraintT, reinterpreted_batch_ndims: int): assert isinstance(base_constraint, Constraint) assert isinstance(reinterpreted_batch_ndims, int) assert reinterpreted_batch_ndims >= 0 @@ -343,19 +345,13 @@ def __init__(self, base_constraint, reinterpreted_batch_ndims): reinterpreted_batch_ndims + base_constraint.reinterpreted_batch_ndims ) base_constraint = base_constraint.base_constraint - self.base_constraint = base_constraint - self.reinterpreted_batch_ndims = reinterpreted_batch_ndims + self.base_constraint: Constraint = base_constraint + self.reinterpreted_batch_ndims: int = reinterpreted_batch_ndims + self.is_discrete = base_constraint.is_discrete + self.event_dim = base_constraint.event_dim + reinterpreted_batch_ndims super().__init__() - @property - def is_discrete(self) -> bool: - return self.base_constraint.is_discrete - - @property - def event_dim(self) -> int: - return self.base_constraint.event_dim + self.reinterpreted_batch_ndims - - def __call__(self, value: ArrayLike) -> ArrayLike: + def __call__(self, value: NumLikeT) -> ArrayLike: result = self.base_constraint(value) if self.reinterpreted_batch_ndims == 0: return result @@ -364,11 +360,12 @@ def __call__(self, value: ArrayLike) -> ArrayLike: raise ValueError( f"Expected value.dim() >= {expected} but got {jax.numpy.ndim(value)}" ) - result = result.reshape( + result = jnp.reshape( + result, jax.numpy.shape(result)[ : jax.numpy.ndim(result) - self.reinterpreted_batch_ndims ] - + (-1,) + + (-1,), ) result = result.all(-1) return result @@ -380,7 +377,7 @@ def __repr__(self) -> str: self.reinterpreted_batch_ndims, ) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLikeT) -> NumLikeT: return self.base_constraint.feasible_like(prototype) def tree_flatten(self): @@ -389,7 +386,7 @@ def tree_flatten(self): {"reinterpreted_batch_ndims": self.reinterpreted_batch_ndims}, ) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _IndependentConstraint): return False @@ -398,22 +395,27 @@ def __eq__(self, other: ConstraintT) -> bool: ) -class _RealVector(_IndependentConstraint, _SingletonConstraint): +class _RealVector( + _IndependentConstraint[NonScalarArray], _SingletonConstraint[NonScalarArray] +): def __init__(self) -> None: super().__init__(_Real(), 1) -class _RealMatrix(_IndependentConstraint, _SingletonConstraint): +class _RealMatrix( + _IndependentConstraint[NonScalarArray], _SingletonConstraint[NonScalarArray] +): def __init__(self) -> None: super().__init__(_Real(), 2) -class _LessThan(Constraint): +class _LessThan(Constraint[NumLike]): def __init__(self, upper_bound: NumLike) -> None: self.upper_bound = upper_bound def __call__(self, x: NumLike) -> ArrayLike: - return x < self.upper_bound + xp = jax.numpy if isinstance(x, jax.Array) else np + return xp.less(x, self.upper_bound) def __repr__(self) -> str: fmt_string = self.__class__.__name__[1:] @@ -426,23 +428,24 @@ def feasible_like(self, prototype: NumLike) -> NumLike: def tree_flatten(self): return (self.upper_bound,), (("upper_bound",), dict()) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _LessThan): return False - return jnp.array_equal(self.upper_bound, other.upper_bound) + return jnp.array_equal(self.upper_bound, other.upper_bound) # type: ignore[return-value] class _LessThanEq(_LessThan): def __call__(self, x: NumLike) -> ArrayLike: - return x <= self.upper_bound + xp = jax.numpy if isinstance(x, jax.Array) else np + return xp.less_equal(x, self.upper_bound) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _LessThanEq): return False - return jnp.array_equal(self.upper_bound, other.upper_bound) + return jnp.array_equal(self.upper_bound, other.upper_bound) # type: ignore[return-value] -class _IntegerInterval(Constraint): +class _IntegerInterval(Constraint[NumLike]): is_discrete = True def __init__(self, lower_bound: NumLike, upper_bound: NumLike) -> None: @@ -450,7 +453,12 @@ def __init__(self, lower_bound: NumLike, upper_bound: NumLike) -> None: self.upper_bound = upper_bound def __call__(self, x: NumLike) -> ArrayLike: - return (x >= self.lower_bound) & (x <= self.upper_bound) & (x % 1 == 0) + xp = jax.numpy if isinstance(x, jax.Array) else np + return ( + xp.greater_equal(x, self.lower_bound) + & xp.less_equal(x, self.upper_bound) + & xp.equal(xp.mod(x, 1), 0) + ) def __repr__(self) -> str: fmt_string = self.__class__.__name__[1:] @@ -468,23 +476,24 @@ def tree_flatten(self): dict(), ) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _IntegerInterval): return False - - return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal( - self.upper_bound, other.upper_bound - ) + return jnp.logical_and( + jnp.array_equal(self.lower_bound, other.lower_bound), + jnp.array_equal(self.upper_bound, other.upper_bound), + ) # type: ignore[return-value] -class _IntegerGreaterThan(Constraint): +class _IntegerGreaterThan(Constraint[NumLike]): is_discrete = True def __init__(self, lower_bound: NumLike) -> None: self.lower_bound = lower_bound def __call__(self, x: NumLike) -> ArrayLike: - return (x % 1 == 0) & (x >= self.lower_bound) + xp = jax.numpy if isinstance(x, jax.Array) else np + return (xp.mod(x, 1) == 0) & xp.greater_equal(x, self.lower_bound) def __repr__(self) -> str: fmt_string = self.__class__.__name__[1:] @@ -497,29 +506,32 @@ def feasible_like(self, prototype: NumLike) -> NumLike: def tree_flatten(self): return (self.lower_bound,), (("lower_bound",), dict()) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _IntegerGreaterThan): return False - return jnp.array_equal(self.lower_bound, other.lower_bound) + return jnp.array_equal(self.lower_bound, other.lower_bound) # type: ignore[return-value] -class _IntegerPositive(_SingletonConstraint, _IntegerGreaterThan): +class _IntegerPositive(_SingletonConstraint[NumLike], _IntegerGreaterThan): def __init__(self) -> None: super().__init__(1) -class _IntegerNonnegative(_SingletonConstraint, _IntegerGreaterThan): +class _IntegerNonnegative(_SingletonConstraint[NumLike], _IntegerGreaterThan): def __init__(self) -> None: super().__init__(0) -class _Interval(Constraint): +class _Interval(Constraint[NumLike]): def __init__(self, lower_bound: NumLike, upper_bound: NumLike) -> None: self.lower_bound = lower_bound self.upper_bound = upper_bound def __call__(self, x: NumLike) -> ArrayLike: - return (x >= self.lower_bound) & (x <= self.upper_bound) + xp = jax.numpy if isinstance(x, jax.Array) else np + return xp.greater_equal(x, self.lower_bound) & xp.less_equal( + x, self.upper_bound + ) def __repr__(self) -> str: fmt_string = self.__class__.__name__[1:] @@ -533,12 +545,12 @@ def feasible_like(self, prototype: NumLike) -> NumLike: (self.lower_bound + self.upper_bound) / 2, jax.numpy.shape(prototype) ) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _Interval): return False return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal( self.upper_bound, other.upper_bound - ) + ) # type: ignore[return-value] def tree_flatten(self): return (self.lower_bound, self.upper_bound), ( @@ -547,19 +559,20 @@ def tree_flatten(self): ) -class _Circular(_SingletonConstraint, _Interval): +class _Circular(_SingletonConstraint[NumLike], _Interval): def __init__(self) -> None: super().__init__(-math.pi, math.pi) -class _UnitInterval(_SingletonConstraint, _Interval): +class _UnitInterval(_SingletonConstraint[NumLike], _Interval): def __init__(self) -> None: super().__init__(0.0, 1.0) class _OpenInterval(_Interval): def __call__(self, x: NumLike) -> ArrayLike: - return (x > self.lower_bound) & (x < self.upper_bound) + xp = jax.numpy if isinstance(x, jax.Array) else np + return xp.greater(x, self.lower_bound) & xp.less(x, self.upper_bound) def __repr__(self) -> str: fmt_string = self.__class__.__name__[1:] @@ -569,16 +582,14 @@ def __repr__(self) -> str: return fmt_string -class _LowerCholesky(_SingletonConstraint): +class _LowerCholesky(_SingletonConstraint[NonScalarArray]): event_dim = 2 def __call__(self, x: NonScalarArray) -> ArrayLike: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy - tril = jnp.tril(x) - lower_triangular = jnp.all( - jnp.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1 - ) - positive_diagonal = jnp.all(jnp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1) + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + tril = xp.tril(x) + lower_triangular = xp.all(xp.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1) + positive_diagonal = xp.all(xp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1) return lower_triangular & positive_diagonal def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: @@ -587,7 +598,7 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: ) -class _Multinomial(Constraint): +class _Multinomial(Constraint[NonScalarArray]): is_discrete = True event_dim = 1 @@ -595,7 +606,8 @@ def __init__(self, upper_bound: ArrayLike) -> None: self.upper_bound = upper_bound def __call__(self, x: NonScalarArray) -> ArrayLike: - return (x >= 0).all(axis=-1) & (x.sum(axis=-1) == self.upper_bound) + xp = jax.numpy if isinstance(x, jax.Array) else np + return (x >= 0).all(axis=-1) & xp.equal(x.sum(axis=-1), self.upper_bound) def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: pad_width = ((0, 0),) * jax.numpy.ndim(self.upper_bound) + ( @@ -607,13 +619,13 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: def tree_flatten(self): return (self.upper_bound,), (("upper_bound",), dict()) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _Multinomial): return False - return jnp.array_equal(self.upper_bound, other.upper_bound) + return jnp.array_equal(self.upper_bound, other.upper_bound) # type: ignore[return-value] -class _L1Ball(_SingletonConstraint): +class _L1Ball(_SingletonConstraint[NumLike]): """ Constrain to the L1 ball of any dimension. """ @@ -622,15 +634,15 @@ class _L1Ball(_SingletonConstraint): reltol = 10.0 # Relative to finfo.eps. def __call__(self, x: NumLike) -> ArrayLike: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy - eps = jnp.finfo(x.dtype).eps - return jnp.abs(x).sum(axis=-1) < 1 + self.reltol * eps + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + eps = xp.finfo(x.dtype if isinstance(x, xp.ndarray) else type(x)).eps + return xp.abs(x).sum(axis=-1) < 1 + self.reltol * eps def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.zeros_like(prototype) -class _OrderedVector(_SingletonConstraint): +class _OrderedVector(_SingletonConstraint[NonScalarArray]): event_dim = 1 def __call__(self, x: NonScalarArray) -> ArrayLike: @@ -642,15 +654,15 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: ) -class _PositiveDefinite(_SingletonConstraint): +class _PositiveDefinite(_SingletonConstraint[NonScalarArray]): event_dim = 2 def __call__(self, x: NonScalarArray) -> ArrayLike: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric - symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) + symmetric = xp.all(xp.isclose(x, xp.swapaxes(x, -2, -1)), axis=(-2, -1)) # check for the smallest eigenvalue is positive - positive = jnp.linalg.eigh(x)[0][..., 0] > 0 + positive = xp.linalg.eigh(x)[0][..., 0] > 0 return symmetric & positive def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: @@ -659,28 +671,28 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: ) -class _PositiveDefiniteCirculantVector(_SingletonConstraint): +class _PositiveDefiniteCirculantVector(_SingletonConstraint[NonScalarArray]): event_dim = 1 def __call__(self, x: NonScalarArray) -> ArrayLike: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy - tol = 10 * jnp.finfo(x.dtype).eps - rfft = jnp.fft.rfft(x) - return (jnp.abs(rfft.imag) < tol) & (rfft.real > -tol) + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + tol = 10 * xp.finfo(x.dtype).eps + rfft = xp.fft.rfft(x) + return (xp.abs(rfft.imag) < tol) & (rfft.real > -tol) def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jnp.zeros_like(prototype).at[..., 0].set(1.0) -class _PositiveSemiDefinite(_SingletonConstraint): +class _PositiveSemiDefinite(_SingletonConstraint[NonScalarArray]): event_dim = 2 def __call__(self, x: NonScalarArray) -> ArrayLike: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric - symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) + symmetric = xp.all(xp.isclose(x, xp.swapaxes(x, -2, -1)), axis=(-2, -1)) # check for the smallest eigenvalue is nonnegative - nonnegative = jnp.linalg.eigh(x)[0][..., 0] >= 0 + nonnegative = xp.linalg.eigh(x)[0][..., 0] >= 0 return symmetric & nonnegative def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: @@ -689,7 +701,7 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: ) -class _PositiveOrderedVector(_SingletonConstraint): +class _PositiveOrderedVector(_SingletonConstraint[NonScalarArray]): """ Constrains to a positive real-valued tensor where the elements are monotonically increasing along the `event_shape` dimension. @@ -698,7 +710,9 @@ class _PositiveOrderedVector(_SingletonConstraint): event_dim = 1 def __call__(self, x: NonScalarArray) -> ArrayLike: - return ordered_vector.check(x) & independent(positive, 1).check(x) + return jnp.logical_and( + ordered_vector.check(x), independent[NumLike](positive, 1).check(x) + ) def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( @@ -706,16 +720,21 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: ) -class _Complex(_SingletonConstraint): +class _Complex(_SingletonConstraint[NumLike]): def __call__(self, x: NumLike) -> ArrayLike: # XXX: consider to relax this condition to [-inf, inf] interval - return (x == x) & (x != float("inf")) & (x != float("-inf")) + xp = jax.numpy if isinstance(x, jax.Array) else np + return ( + xp.equal(x, x) + & xp.not_equal(x, float("inf")) + & xp.not_equal(x, float("-inf")) + ) def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.zeros_like(prototype) -class _Real(_SingletonConstraint): +class _Real(_SingletonConstraint[NumLike]): def __call__(self, x: NumLike) -> ArrayLike: # XXX: consider to relax this condition to [-inf, inf] interval return (x == x) & (x != float("inf")) & (x != float("-inf")) @@ -724,7 +743,7 @@ def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.zeros_like(prototype) -class _Simplex(_SingletonConstraint): +class _Simplex(_SingletonConstraint[NonScalarArray]): event_dim = 1 def __call__(self, x: NonScalarArray) -> ArrayLike: @@ -735,7 +754,7 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.full_like(prototype, 1 / prototype.shape[-1]) -class _SoftplusPositive(_SingletonConstraint, _GreaterThan): +class _SoftplusPositive(_SingletonConstraint[NumLike], _GreaterThan): def __init__(self) -> None: super().__init__(lower_bound=0.0) @@ -754,7 +773,7 @@ class _ScaledUnitLowerCholesky(_LowerCholesky): pass -class _Sphere(_SingletonConstraint): +class _Sphere(_SingletonConstraint[NonScalarArray]): """ Constrain to the Euclidean sphere of any dimension. """ @@ -763,31 +782,33 @@ class _Sphere(_SingletonConstraint): reltol = 10.0 # Relative to finfo.eps. def __call__(self, x: NonScalarArray) -> ArrayLike: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy - eps = jnp.finfo(x.dtype).eps - norm = jnp.linalg.norm(x, axis=-1) - error = jnp.abs(norm - 1) + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + eps = xp.finfo(x.dtype).eps + norm = xp.linalg.norm(x, axis=-1) + error = xp.abs(norm - 1) return error < self.reltol * eps * x.shape[-1] ** 0.5 def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.full_like(prototype, prototype.shape[-1] ** (-0.5)) -class _ZeroSum(Constraint): +class _ZeroSum(Constraint[NonScalarArray]): def __init__(self, event_dim: int = 1) -> None: self.event_dim = event_dim super().__init__() def __call__(self, x: NonScalarArray) -> ArrayLike: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy - tol = jnp.finfo(x.dtype).eps * x.shape[-1] * 10 + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + tol = xp.finfo(x.dtype).eps * x.shape[-1] * 10 zerosum_true = True for dim in range(-self.event_dim, 0): - zerosum_true = zerosum_true & jnp.allclose(x.sum(dim), 0, atol=tol) + zerosum_true = zerosum_true & xp.allclose(x.sum(dim), 0, atol=tol) return zerosum_true - def __eq__(self, other: ConstraintT) -> bool: - return type(self) is type(other) and self.event_dim == other.event_dim + def __eq__(self, other: object) -> bool: + if not isinstance(other, _ZeroSum): + return False + return self.event_dim == other.event_dim def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.zeros_like(prototype) @@ -799,6 +820,7 @@ def tree_flatten(self): # TODO: Make types consistent # See https://github.com/pytorch/pytorch/issues/50616 + boolean: ConstraintT = _Boolean() circular: ConstraintT = _Circular() complex: ConstraintT = _Complex() diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 4512225c8..11cf179e0 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -3,7 +3,7 @@ import math -from typing import Generic, Optional, Sequence, Tuple, TypeVar, Union, cast +from typing import Generic, Optional, Sequence, Tuple, Union, cast import warnings import weakref @@ -18,7 +18,14 @@ from jax.tree_util import register_pytree_node from jax.typing import ArrayLike -from numpyro._typing import ConstraintT, NonScalarArray, NumLike, PyTree, TransformT +from numpyro._typing import ( + ConstraintT, + NonScalarArray, + NumLike, + NumLikeT, + PyTree, + TransformT, +) from numpyro.distributions import constraints from numpyro.distributions.util import ( add_diag, @@ -65,9 +72,6 @@ def _clipped_expit(x: NumLike) -> NumLike: return jnp.clip(expit(x), finfo.tiny, 1.0 - finfo.eps) -NumLikeT = TypeVar("NumLikeT", bound=NumLike) - - class Transform(Generic[NumLikeT]): _inv: Optional[Union[TransformT, weakref.ref]] = None diff --git a/pyproject.toml b/pyproject.toml index 47925bf7a..f3b504427 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,6 +130,7 @@ module = [ "numpyro.primitives.*", "numpyro.patch.*", "numpyro.util.*", + "numpyro.distributions.constraints", "numpyro.distributions.transforms", ] ignore_errors = false