Skip to content

Commit

Permalink
More consistent output from link functions
Browse files Browse the repository at this point in the history
  • Loading branch information
lbittarello committed Feb 12, 2024
1 parent e342311 commit c8bae79
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions src/glum/_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ class IdentityLink(Link):
def __eq__(self, other): # noqa D
return isinstance(other, self.__class__)

# unnecessary type hint for consistency with other methods
def link(self, mu):
"""Return mu (identity link).
Expand All @@ -102,7 +101,6 @@ def derivative(self, mu):
"""
return 1.0 if np.isscalar(mu) else np.ones_like(mu)

# unnecessary type hint for consistency with other methods
def inverse(self, lin_pred):
"""Compute the inverse link function ``h(lin_pred)``.
Expand Down Expand Up @@ -156,7 +154,7 @@ def link(self, mu):
-------
numpy.ndarray
"""
return np.log(mu)
return np.log(_asanyarray(mu))

def derivative(self, mu):
"""Get the derivative of the log link, one over ``mu``.
Expand All @@ -169,7 +167,7 @@ def derivative(self, mu):
-------
numpy.ndarray
"""
return 1 / mu
return 1 / _asanyarray(mu)

def inverse(self, lin_pred):
"""Get the inverse of the log link, the exponential function.
Expand All @@ -184,7 +182,7 @@ def inverse(self, lin_pred):
-------
numpy.ndarray
"""
return np.exp(lin_pred)
return np.exp(_asanyarray(lin_pred))

def inverse_derivative(self, lin_pred):
"""Compute the derivative of the inverse link function ``h'(lin_pred)``.
Expand All @@ -194,17 +192,17 @@ def inverse_derivative(self, lin_pred):
lin_pred : array-like, shape (n_samples,)
Usually the (fitted) linear predictor.
"""
return np.exp(lin_pred)
return np.exp(_asanyarray(lin_pred))

def inverse_derivative2(self, lin_pred):
"""Compute 2nd derivative of the inverse link function ``h''(lin_pred)``.
"""Compute second derivative of the inverse link function ``h''(lin_pred)``.
Parameters
----------
lin_pred : array-like, shape (n_samples,)
Usually the (fitted) linear predictor.
"""
return np.exp(lin_pred)
return np.exp(_asanyarray(lin_pred))


class LogitLink(Link):
Expand Down Expand Up @@ -282,7 +280,7 @@ def inverse_derivative(self, lin_pred):
return ep * (1.0 - ep)

def inverse_derivative2(self, lin_pred):
"""Compute 2nd derivative of the inverse link function ``h''(lin_pred)``.
"""Compute second derivative of the inverse link function ``h''(lin_pred)``.
Parameters
----------
Expand Down Expand Up @@ -394,7 +392,7 @@ def inverse_derivative(self, lin_pred):

@catch_p
def inverse_derivative2(self, lin_pred):
"""Compute secondnd derivative of the inverse Tweedie link function \
"""Compute second derivative of the inverse Tweedie link function \
``h''(lin_pred)``.
Parameters
Expand Down Expand Up @@ -484,7 +482,7 @@ def inverse_derivative(self, lin_pred):
return np.exp(lin_pred - np.exp(lin_pred))

def inverse_derivative2(self, lin_pred):
"""Compute 2nd derivative of the inverse link function ``h''(lin_pred)``.
"""Compute second derivative of the inverse link function ``h''(lin_pred)``.
Parameters
----------
Expand Down

0 comments on commit c8bae79

Please sign in to comment.