Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More consistent output from link functions #762

Merged
merged 4 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 34 additions & 31 deletions src/glum/_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ class IdentityLink(Link):
def __eq__(self, other): # noqa D
return isinstance(other, self.__class__)

# unnecessary type hint for consistency with other methods
def link(self, mu):
"""Return mu (identity link).

Expand All @@ -102,7 +101,6 @@ def derivative(self, mu):
"""
return 1.0 if np.isscalar(mu) else np.ones_like(mu)

# unnecessary type hint for consistency with other methods
def inverse(self, lin_pred):
"""Compute the inverse link function ``h(lin_pred)``.

Expand Down Expand Up @@ -156,7 +154,7 @@ def link(self, mu):
-------
numpy.ndarray
"""
return np.log(mu)
return np.log(_asanyarray(mu))

def derivative(self, mu):
"""Get the derivative of the log link, one over ``mu``.
Expand All @@ -169,7 +167,7 @@ def derivative(self, mu):
-------
numpy.ndarray
"""
return 1 / mu
return 1 / _asanyarray(mu)

def inverse(self, lin_pred):
"""Get the inverse of the log link, the exponential function.
Expand All @@ -184,7 +182,7 @@ def inverse(self, lin_pred):
-------
numpy.ndarray
"""
return np.exp(lin_pred)
return np.exp(_asanyarray(lin_pred))

def inverse_derivative(self, lin_pred):
"""Compute the derivative of the inverse link function ``h'(lin_pred)``.
Expand All @@ -194,17 +192,17 @@ def inverse_derivative(self, lin_pred):
lin_pred : array-like, shape (n_samples,)
Usually the (fitted) linear predictor.
"""
return np.exp(lin_pred)
return np.exp(_asanyarray(lin_pred))

def inverse_derivative2(self, lin_pred):
"""Compute 2nd derivative of the inverse link function ``h''(lin_pred)``.
"""Compute second derivative of the inverse link function ``h''(lin_pred)``.

Parameters
----------
lin_pred : array-like, shape (n_samples,)
Usually the (fitted) linear predictor.
"""
return np.exp(lin_pred)
return np.exp(_asanyarray(lin_pred))


class LogitLink(Link):
Expand Down Expand Up @@ -282,7 +280,7 @@ def inverse_derivative(self, lin_pred):
return ep * (1.0 - ep)

def inverse_derivative2(self, lin_pred):
"""Compute 2nd derivative of the inverse link function ``h''(lin_pred)``.
"""Compute second derivative of the inverse link function ``h''(lin_pred)``.

Parameters
----------
Expand Down Expand Up @@ -324,61 +322,56 @@ def _to_return(*args, **kwargs):
class TweedieLink(Link):
"""The Tweedie link function ``x^(1-p)`` if ``p≠1`` and ``log(x)`` if ``p=1``."""

def __new__(cls, p: float):
"""
Create a new ``TweedieLink`` object.

Parameters
----------
p: scalar
"""
if p == 0:
return IdentityLink()
if p == 1:
return LogLink()
return super().__new__(cls)

def __init__(self, p):
self.p = p

def __eq__(self, other): # noqa D
return isinstance(other, self.__class__) and (self.p == other.p)

def link(self, mu):
"""
Get the Tweedie canonical link.
"""Get the Tweedie canonical link.

See superclass documentation.

Parameters
----------
mu: array-like
"""
if self.p == 0:
return _asanyarray(mu)
if self.p == 1:
return np.log(_asanyarray(mu))
return _asanyarray(mu) ** (1 - self.p)

def derivative(self, mu):
"""
Get the derivative of the Tweedie link.
"""Get the derivative of the Tweedie link.

See superclass documentation.

Parameters
----------
mu: array-like
"""
if self.p == 0:
return 1.0 if np.isscalar(mu) else np.ones_like(mu)
if self.p == 1:
return 1 / _asanyarray(mu)
return (1 - self.p) * _asanyarray(mu) ** (-self.p)

@catch_p
def inverse(self, lin_pred):
"""
Get the inverse of the Tweedie link.
"""Get the inverse of the Tweedie link.

See superclass documentation.

Parameters
----------
mu: array-like
"""
if self.p == 0:
return _asanyarray(lin_pred)
if self.p == 1:
return np.exp(_asanyarray(lin_pred))
return _asanyarray(lin_pred) ** (1 / (1 - self.p))

@catch_p
Expand All @@ -390,20 +383,30 @@ def inverse_derivative(self, lin_pred):
lin_pred : array-like, shape (n_samples,)
Usually the (fitted) linear predictor.
"""
if self.p == 0:
return 1.0 if np.isscalar(lin_pred) else np.ones_like(lin_pred)
if self.p == 1:
return np.exp(_asanyarray(lin_pred))
return (1 / (1 - self.p)) * _asanyarray(lin_pred) ** (self.p / (1 - self.p))

@catch_p
def inverse_derivative2(self, lin_pred):
"""Compute secondnd derivative of the inverse Tweedie link function \
"""Compute second derivative of the inverse Tweedie link function \
``h''(lin_pred)``.

Parameters
----------
lin_pred : array, shape (n_samples,)
Usually the (fitted) linear predictor.
"""
if self.p == 0:
return 0.0 if np.isscalar(lin_pred) else np.zeros_like(lin_pred)
if self.p == 1:
return np.exp(_asanyarray(lin_pred))

result = _asanyarray(lin_pred) ** ((2 * self.p - 1) / (1 - self.p))
result *= self.p / (1 - self.p) ** 2

return result


Expand Down Expand Up @@ -484,7 +487,7 @@ def inverse_derivative(self, lin_pred):
return np.exp(lin_pred - np.exp(lin_pred))

def inverse_derivative2(self, lin_pred):
"""Compute 2nd derivative of the inverse link function ``h''(lin_pred)``.
"""Compute second derivative of the inverse link function ``h''(lin_pred)``.

Parameters
----------
Expand Down
2 changes: 1 addition & 1 deletion tests/glm/test_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_link_properties(link):

def test_equality():
assert TweedieLink(1.5) == TweedieLink(1.5)
assert TweedieLink(1) == LogLink()
assert TweedieLink(1) != LogLink()
assert LogLink() == LogLink()
assert TweedieLink(1.5) != TweedieLink(2.5)
assert TweedieLink(1.5) != LogitLink()
Expand Down