diff --git a/src/glum/_link.py b/src/glum/_link.py index 76370563..8a19b731 100644 --- a/src/glum/_link.py +++ b/src/glum/_link.py @@ -79,7 +79,6 @@ class IdentityLink(Link): def __eq__(self, other): # noqa D return isinstance(other, self.__class__) - # unnecessary type hint for consistency with other methods def link(self, mu): """Return mu (identity link). @@ -102,7 +101,6 @@ def derivative(self, mu): """ return 1.0 if np.isscalar(mu) else np.ones_like(mu) - # unnecessary type hint for consistency with other methods def inverse(self, lin_pred): """Compute the inverse link function ``h(lin_pred)``. @@ -156,7 +154,7 @@ def link(self, mu): ------- numpy.ndarray """ - return np.log(mu) + return np.log(_asanyarray(mu)) def derivative(self, mu): """Get the derivative of the log link, one over ``mu``. @@ -169,7 +167,7 @@ def derivative(self, mu): ------- numpy.ndarray """ - return 1 / mu + return 1 / _asanyarray(mu) def inverse(self, lin_pred): """Get the inverse of the log link, the exponential function. @@ -184,7 +182,7 @@ def inverse(self, lin_pred): ------- numpy.ndarray """ - return np.exp(lin_pred) + return np.exp(_asanyarray(lin_pred)) def inverse_derivative(self, lin_pred): """Compute the derivative of the inverse link function ``h'(lin_pred)``. @@ -194,17 +192,17 @@ def inverse_derivative(self, lin_pred): lin_pred : array-like, shape (n_samples,) Usually the (fitted) linear predictor. """ - return np.exp(lin_pred) + return np.exp(_asanyarray(lin_pred)) def inverse_derivative2(self, lin_pred): - """Compute 2nd derivative of the inverse link function ``h''(lin_pred)``. + """Compute second derivative of the inverse link function ``h''(lin_pred)``. Parameters ---------- lin_pred : array-like, shape (n_samples,) Usually the (fitted) linear predictor. """ - return np.exp(lin_pred) + return np.exp(_asanyarray(lin_pred)) class LogitLink(Link): @@ -282,7 +280,7 @@ def inverse_derivative(self, lin_pred): return ep * (1.0 - ep) def inverse_derivative2(self, lin_pred): - """Compute 2nd derivative of the inverse link function ``h''(lin_pred)``. + """Compute second derivative of the inverse link function ``h''(lin_pred)``. Parameters ---------- @@ -324,20 +322,6 @@ def _to_return(*args, **kwargs): class TweedieLink(Link): """The Tweedie link function ``x^(1-p)`` if ``p≠1`` and ``log(x)`` if ``p=1``.""" - def __new__(cls, p: float): - """ - Create a new ``TweedieLink`` object. - - Parameters - ---------- - p: scalar - """ - if p == 0: - return IdentityLink() - if p == 1: - return LogLink() - return super().__new__(cls) - def __init__(self, p): self.p = p @@ -345,8 +329,7 @@ def __eq__(self, other): # noqa D return isinstance(other, self.__class__) and (self.p == other.p) def link(self, mu): - """ - Get the Tweedie canonical link. + """Get the Tweedie canonical link. See superclass documentation. @@ -354,11 +337,14 @@ def link(self, mu): ---------- mu: array-like """ + if self.p == 0: + return _asanyarray(mu) + if self.p == 1: + return np.log(_asanyarray(mu)) return _asanyarray(mu) ** (1 - self.p) def derivative(self, mu): - """ - Get the derivative of the Tweedie link. + """Get the derivative of the Tweedie link. See superclass documentation. @@ -366,12 +352,15 @@ def derivative(self, mu): ---------- mu: array-like """ + if self.p == 0: + return 1.0 if np.isscalar(mu) else np.ones_like(mu) + if self.p == 1: + return 1 / _asanyarray(mu) return (1 - self.p) * _asanyarray(mu) ** (-self.p) @catch_p def inverse(self, lin_pred): - """ - Get the inverse of the Tweedie link. + """Get the inverse of the Tweedie link. See superclass documentation. @@ -379,6 +368,10 @@ def inverse(self, lin_pred): ---------- mu: array-like """ + if self.p == 0: + return _asanyarray(lin_pred) + if self.p == 1: + return np.exp(_asanyarray(lin_pred)) return _asanyarray(lin_pred) ** (1 / (1 - self.p)) @catch_p @@ -390,11 +383,15 @@ def inverse_derivative(self, lin_pred): lin_pred : array-like, shape (n_samples,) Usually the (fitted) linear predictor. """ + if self.p == 0: + return 1.0 if np.isscalar(lin_pred) else np.ones_like(lin_pred) + if self.p == 1: + return np.exp(_asanyarray(lin_pred)) return (1 / (1 - self.p)) * _asanyarray(lin_pred) ** (self.p / (1 - self.p)) @catch_p def inverse_derivative2(self, lin_pred): - """Compute secondnd derivative of the inverse Tweedie link function \ + """Compute second derivative of the inverse Tweedie link function \ ``h''(lin_pred)``. Parameters @@ -402,8 +399,14 @@ def inverse_derivative2(self, lin_pred): lin_pred : array, shape (n_samples,) Usually the (fitted) linear predictor. """ + if self.p == 0: + return 0.0 if np.isscalar(lin_pred) else np.zeros_like(lin_pred) + if self.p == 1: + return np.exp(_asanyarray(lin_pred)) + result = _asanyarray(lin_pred) ** ((2 * self.p - 1) / (1 - self.p)) result *= self.p / (1 - self.p) ** 2 + return result @@ -484,7 +487,7 @@ def inverse_derivative(self, lin_pred): return np.exp(lin_pred - np.exp(lin_pred)) def inverse_derivative2(self, lin_pred): - """Compute 2nd derivative of the inverse link function ``h''(lin_pred)``. + """Compute second derivative of the inverse link function ``h''(lin_pred)``. Parameters ---------- diff --git a/tests/glm/test_link.py b/tests/glm/test_link.py index 12e581b7..02052afb 100644 --- a/tests/glm/test_link.py +++ b/tests/glm/test_link.py @@ -35,7 +35,7 @@ def test_link_properties(link): def test_equality(): assert TweedieLink(1.5) == TweedieLink(1.5) - assert TweedieLink(1) == LogLink() + assert TweedieLink(1) != LogLink() assert LogLink() == LogLink() assert TweedieLink(1.5) != TweedieLink(2.5) assert TweedieLink(1.5) != LogitLink()