Skip to content

Commit c8bae79

Browse files
committed
More consistent output from link functions
1 parent e342311 commit c8bae79

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

src/glum/_link.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ class IdentityLink(Link):
7979
def __eq__(self, other): # noqa D
8080
return isinstance(other, self.__class__)
8181

82-
# unnecessary type hint for consistency with other methods
8382
def link(self, mu):
8483
"""Return mu (identity link).
8584
@@ -102,7 +101,6 @@ def derivative(self, mu):
102101
"""
103102
return 1.0 if np.isscalar(mu) else np.ones_like(mu)
104103

105-
# unnecessary type hint for consistency with other methods
106104
def inverse(self, lin_pred):
107105
"""Compute the inverse link function ``h(lin_pred)``.
108106
@@ -156,7 +154,7 @@ def link(self, mu):
156154
-------
157155
numpy.ndarray
158156
"""
159-
return np.log(mu)
157+
return np.log(_asanyarray(mu))
160158

161159
def derivative(self, mu):
162160
"""Get the derivative of the log link, one over ``mu``.
@@ -169,7 +167,7 @@ def derivative(self, mu):
169167
-------
170168
numpy.ndarray
171169
"""
172-
return 1 / mu
170+
return 1 / _asanyarray(mu)
173171

174172
def inverse(self, lin_pred):
175173
"""Get the inverse of the log link, the exponential function.
@@ -184,7 +182,7 @@ def inverse(self, lin_pred):
184182
-------
185183
numpy.ndarray
186184
"""
187-
return np.exp(lin_pred)
185+
return np.exp(_asanyarray(lin_pred))
188186

189187
def inverse_derivative(self, lin_pred):
190188
"""Compute the derivative of the inverse link function ``h'(lin_pred)``.
@@ -194,17 +192,17 @@ def inverse_derivative(self, lin_pred):
194192
lin_pred : array-like, shape (n_samples,)
195193
Usually the (fitted) linear predictor.
196194
"""
197-
return np.exp(lin_pred)
195+
return np.exp(_asanyarray(lin_pred))
198196

199197
def inverse_derivative2(self, lin_pred):
200-
"""Compute 2nd derivative of the inverse link function ``h''(lin_pred)``.
198+
"""Compute second derivative of the inverse link function ``h''(lin_pred)``.
201199
202200
Parameters
203201
----------
204202
lin_pred : array-like, shape (n_samples,)
205203
Usually the (fitted) linear predictor.
206204
"""
207-
return np.exp(lin_pred)
205+
return np.exp(_asanyarray(lin_pred))
208206

209207

210208
class LogitLink(Link):
@@ -282,7 +280,7 @@ def inverse_derivative(self, lin_pred):
282280
return ep * (1.0 - ep)
283281

284282
def inverse_derivative2(self, lin_pred):
285-
"""Compute 2nd derivative of the inverse link function ``h''(lin_pred)``.
283+
"""Compute second derivative of the inverse link function ``h''(lin_pred)``.
286284
287285
Parameters
288286
----------
@@ -394,7 +392,7 @@ def inverse_derivative(self, lin_pred):
394392

395393
@catch_p
396394
def inverse_derivative2(self, lin_pred):
397-
"""Compute secondnd derivative of the inverse Tweedie link function \
395+
"""Compute second derivative of the inverse Tweedie link function \
398396
``h''(lin_pred)``.
399397
400398
Parameters
@@ -484,7 +482,7 @@ def inverse_derivative(self, lin_pred):
484482
return np.exp(lin_pred - np.exp(lin_pred))
485483

486484
def inverse_derivative2(self, lin_pred):
487-
"""Compute 2nd derivative of the inverse link function ``h''(lin_pred)``.
485+
"""Compute second derivative of the inverse link function ``h''(lin_pred)``.
488486
489487
Parameters
490488
----------

0 commit comments

Comments
 (0)