Skip to content

Commit

Permalink
More consistent output from link functions (#762)
Browse files Browse the repository at this point in the history
  • Loading branch information
lbittarello authored Feb 21, 2024
1 parent 3ba4209 commit 1284375
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 32 deletions.
65 changes: 34 additions & 31 deletions src/glum/_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ class IdentityLink(Link):
def __eq__(self, other): # noqa D
return isinstance(other, self.__class__)

# unnecessary type hint for consistency with other methods
def link(self, mu):
"""Return mu (identity link).
Expand All @@ -102,7 +101,6 @@ def derivative(self, mu):
"""
return 1.0 if np.isscalar(mu) else np.ones_like(mu)

# unnecessary type hint for consistency with other methods
def inverse(self, lin_pred):
"""Compute the inverse link function ``h(lin_pred)``.
Expand Down Expand Up @@ -156,7 +154,7 @@ def link(self, mu):
-------
numpy.ndarray
"""
return np.log(mu)
return np.log(_asanyarray(mu))

def derivative(self, mu):
"""Get the derivative of the log link, one over ``mu``.
Expand All @@ -169,7 +167,7 @@ def derivative(self, mu):
-------
numpy.ndarray
"""
return 1 / mu
return 1 / _asanyarray(mu)

def inverse(self, lin_pred):
"""Get the inverse of the log link, the exponential function.
Expand All @@ -184,7 +182,7 @@ def inverse(self, lin_pred):
-------
numpy.ndarray
"""
return np.exp(lin_pred)
return np.exp(_asanyarray(lin_pred))

def inverse_derivative(self, lin_pred):
"""Compute the derivative of the inverse link function ``h'(lin_pred)``.
Expand All @@ -194,17 +192,17 @@ def inverse_derivative(self, lin_pred):
lin_pred : array-like, shape (n_samples,)
Usually the (fitted) linear predictor.
"""
return np.exp(lin_pred)
return np.exp(_asanyarray(lin_pred))

def inverse_derivative2(self, lin_pred):
"""Compute 2nd derivative of the inverse link function ``h''(lin_pred)``.
"""Compute second derivative of the inverse link function ``h''(lin_pred)``.
Parameters
----------
lin_pred : array-like, shape (n_samples,)
Usually the (fitted) linear predictor.
"""
return np.exp(lin_pred)
return np.exp(_asanyarray(lin_pred))


class LogitLink(Link):
Expand Down Expand Up @@ -282,7 +280,7 @@ def inverse_derivative(self, lin_pred):
return ep * (1.0 - ep)

def inverse_derivative2(self, lin_pred):
"""Compute 2nd derivative of the inverse link function ``h''(lin_pred)``.
"""Compute second derivative of the inverse link function ``h''(lin_pred)``.
Parameters
----------
Expand Down Expand Up @@ -324,61 +322,56 @@ def _to_return(*args, **kwargs):
class TweedieLink(Link):
"""The Tweedie link function ``x^(1-p)`` if ``p≠1`` and ``log(x)`` if ``p=1``."""

def __new__(cls, p: float):
"""
Create a new ``TweedieLink`` object.
Parameters
----------
p: scalar
"""
if p == 0:
return IdentityLink()
if p == 1:
return LogLink()
return super().__new__(cls)

def __init__(self, p):
self.p = p

def __eq__(self, other): # noqa D
return isinstance(other, self.__class__) and (self.p == other.p)

def link(self, mu):
"""
Get the Tweedie canonical link.
"""Get the Tweedie canonical link.
See superclass documentation.
Parameters
----------
mu: array-like
"""
if self.p == 0:
return _asanyarray(mu)
if self.p == 1:
return np.log(_asanyarray(mu))
return _asanyarray(mu) ** (1 - self.p)

def derivative(self, mu):
"""
Get the derivative of the Tweedie link.
"""Get the derivative of the Tweedie link.
See superclass documentation.
Parameters
----------
mu: array-like
"""
if self.p == 0:
return 1.0 if np.isscalar(mu) else np.ones_like(mu)
if self.p == 1:
return 1 / _asanyarray(mu)
return (1 - self.p) * _asanyarray(mu) ** (-self.p)

@catch_p
def inverse(self, lin_pred):
"""
Get the inverse of the Tweedie link.
"""Get the inverse of the Tweedie link.
See superclass documentation.
Parameters
----------
mu: array-like
"""
if self.p == 0:
return _asanyarray(lin_pred)
if self.p == 1:
return np.exp(_asanyarray(lin_pred))
return _asanyarray(lin_pred) ** (1 / (1 - self.p))

@catch_p
Expand All @@ -390,20 +383,30 @@ def inverse_derivative(self, lin_pred):
lin_pred : array-like, shape (n_samples,)
Usually the (fitted) linear predictor.
"""
if self.p == 0:
return 1.0 if np.isscalar(lin_pred) else np.ones_like(lin_pred)
if self.p == 1:
return np.exp(_asanyarray(lin_pred))
return (1 / (1 - self.p)) * _asanyarray(lin_pred) ** (self.p / (1 - self.p))

@catch_p
def inverse_derivative2(self, lin_pred):
"""Compute secondnd derivative of the inverse Tweedie link function \
"""Compute second derivative of the inverse Tweedie link function \
``h''(lin_pred)``.
Parameters
----------
lin_pred : array, shape (n_samples,)
Usually the (fitted) linear predictor.
"""
if self.p == 0:
return 0.0 if np.isscalar(lin_pred) else np.zeros_like(lin_pred)
if self.p == 1:
return np.exp(_asanyarray(lin_pred))

result = _asanyarray(lin_pred) ** ((2 * self.p - 1) / (1 - self.p))
result *= self.p / (1 - self.p) ** 2

return result


Expand Down Expand Up @@ -484,7 +487,7 @@ def inverse_derivative(self, lin_pred):
return np.exp(lin_pred - np.exp(lin_pred))

def inverse_derivative2(self, lin_pred):
"""Compute 2nd derivative of the inverse link function ``h''(lin_pred)``.
"""Compute second derivative of the inverse link function ``h''(lin_pred)``.
Parameters
----------
Expand Down
2 changes: 1 addition & 1 deletion tests/glm/test_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_link_properties(link):

def test_equality():
assert TweedieLink(1.5) == TweedieLink(1.5)
assert TweedieLink(1) == LogLink()
assert TweedieLink(1) != LogLink()
assert LogLink() == LogLink()
assert TweedieLink(1.5) != TweedieLink(2.5)
assert TweedieLink(1.5) != LogitLink()
Expand Down

0 comments on commit 1284375

Please sign in to comment.