Skip to content

Commit

Permalink
Remove TweedieLink.__new__
Browse files Browse the repository at this point in the history
  • Loading branch information
lbittarello committed Feb 14, 2024
1 parent 2c5297b commit 58b1066
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions src/glum/_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,19 +322,6 @@ def _to_return(*args, **kwargs):
class TweedieLink(Link):
"""The Tweedie link function ``x^(1-p)`` if ``p≠1`` and ``log(x)`` if ``p=1``."""

def __new__(cls, p: float):
"""Create a new ``TweedieLink`` object.
Parameters
----------
p: scalar
"""
if p == 0:
return IdentityLink()
if p == 1:
return LogLink()
return super().__new__(cls)

def __init__(self, p):
self.p = p

Expand All @@ -350,6 +337,10 @@ def link(self, mu):
----------
mu: array-like
"""
if self.p == 0:
return _asanyarray(mu)
if self.p == 1:
return np.log(_asanyarray(mu))
return _asanyarray(mu) ** (1 - self.p)

def derivative(self, mu):
Expand All @@ -361,6 +352,10 @@ def derivative(self, mu):
----------
mu: array-like
"""
if self.p == 0:
return 1.0 if np.isscalar(mu) else np.ones_like(mu)
if self.p == 1:
return 1 / _asanyarray(mu)
return (1 - self.p) * _asanyarray(mu) ** (-self.p)

@catch_p
Expand All @@ -373,6 +368,10 @@ def inverse(self, lin_pred):
----------
mu: array-like
"""
if self.p == 0:
return _asanyarray(lin_pred)
if self.p == 1:
return np.exp(_asanyarray(lin_pred))
return _asanyarray(lin_pred) ** (1 / (1 - self.p))

@catch_p
Expand All @@ -384,6 +383,10 @@ def inverse_derivative(self, lin_pred):
lin_pred : array-like, shape (n_samples,)
Usually the (fitted) linear predictor.
"""
if self.p == 0:
return 1.0 if np.isscalar(lin_pred) else np.ones_like(lin_pred)
if self.p == 1:
return np.exp(_asanyarray(lin_pred))
return (1 / (1 - self.p)) * _asanyarray(lin_pred) ** (self.p / (1 - self.p))

@catch_p
Expand All @@ -396,8 +399,14 @@ def inverse_derivative2(self, lin_pred):
lin_pred : array, shape (n_samples,)
Usually the (fitted) linear predictor.
"""
if self.p == 0:
return 0.0 if np.isscalar(lin_pred) else np.zeros_like(lin_pred)
if self.p == 1:
return np.exp(_asanyarray(lin_pred))

result = _asanyarray(lin_pred) ** ((2 * self.p - 1) / (1 - self.p))
result *= self.p / (1 - self.p) ** 2

return result


Expand Down

0 comments on commit 58b1066

Please sign in to comment.