Skip to content

Commit 1284375

Browse files
authored
More consistent output from link functions (#762)
1 parent 3ba4209 commit 1284375

File tree

2 files changed

+35
-32
lines changed

2 files changed

+35
-32
lines changed

src/glum/_link.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ class IdentityLink(Link):
7979
def __eq__(self, other): # noqa D
8080
return isinstance(other, self.__class__)
8181

82-
# unnecessary type hint for consistency with other methods
8382
def link(self, mu):
8483
"""Return mu (identity link).
8584
@@ -102,7 +101,6 @@ def derivative(self, mu):
102101
"""
103102
return 1.0 if np.isscalar(mu) else np.ones_like(mu)
104103

105-
# unnecessary type hint for consistency with other methods
106104
def inverse(self, lin_pred):
107105
"""Compute the inverse link function ``h(lin_pred)``.
108106
@@ -156,7 +154,7 @@ def link(self, mu):
156154
-------
157155
numpy.ndarray
158156
"""
159-
return np.log(mu)
157+
return np.log(_asanyarray(mu))
160158

161159
def derivative(self, mu):
162160
"""Get the derivative of the log link, one over ``mu``.
@@ -169,7 +167,7 @@ def derivative(self, mu):
169167
-------
170168
numpy.ndarray
171169
"""
172-
return 1 / mu
170+
return 1 / _asanyarray(mu)
173171

174172
def inverse(self, lin_pred):
175173
"""Get the inverse of the log link, the exponential function.
@@ -184,7 +182,7 @@ def inverse(self, lin_pred):
184182
-------
185183
numpy.ndarray
186184
"""
187-
return np.exp(lin_pred)
185+
return np.exp(_asanyarray(lin_pred))
188186

189187
def inverse_derivative(self, lin_pred):
190188
"""Compute the derivative of the inverse link function ``h'(lin_pred)``.
@@ -194,17 +192,17 @@ def inverse_derivative(self, lin_pred):
194192
lin_pred : array-like, shape (n_samples,)
195193
Usually the (fitted) linear predictor.
196194
"""
197-
return np.exp(lin_pred)
195+
return np.exp(_asanyarray(lin_pred))
198196

199197
def inverse_derivative2(self, lin_pred):
200-
"""Compute 2nd derivative of the inverse link function ``h''(lin_pred)``.
198+
"""Compute second derivative of the inverse link function ``h''(lin_pred)``.
201199
202200
Parameters
203201
----------
204202
lin_pred : array-like, shape (n_samples,)
205203
Usually the (fitted) linear predictor.
206204
"""
207-
return np.exp(lin_pred)
205+
return np.exp(_asanyarray(lin_pred))
208206

209207

210208
class LogitLink(Link):
@@ -282,7 +280,7 @@ def inverse_derivative(self, lin_pred):
282280
return ep * (1.0 - ep)
283281

284282
def inverse_derivative2(self, lin_pred):
285-
"""Compute 2nd derivative of the inverse link function ``h''(lin_pred)``.
283+
"""Compute second derivative of the inverse link function ``h''(lin_pred)``.
286284
287285
Parameters
288286
----------
@@ -324,61 +322,56 @@ def _to_return(*args, **kwargs):
324322
class TweedieLink(Link):
325323
"""The Tweedie link function ``x^(1-p)`` if ``p≠1`` and ``log(x)`` if ``p=1``."""
326324

327-
def __new__(cls, p: float):
328-
"""
329-
Create a new ``TweedieLink`` object.
330-
331-
Parameters
332-
----------
333-
p: scalar
334-
"""
335-
if p == 0:
336-
return IdentityLink()
337-
if p == 1:
338-
return LogLink()
339-
return super().__new__(cls)
340-
341325
def __init__(self, p):
342326
self.p = p
343327

344328
def __eq__(self, other): # noqa D
345329
return isinstance(other, self.__class__) and (self.p == other.p)
346330

347331
def link(self, mu):
348-
"""
349-
Get the Tweedie canonical link.
332+
"""Get the Tweedie canonical link.
350333
351334
See superclass documentation.
352335
353336
Parameters
354337
----------
355338
mu: array-like
356339
"""
340+
if self.p == 0:
341+
return _asanyarray(mu)
342+
if self.p == 1:
343+
return np.log(_asanyarray(mu))
357344
return _asanyarray(mu) ** (1 - self.p)
358345

359346
def derivative(self, mu):
360-
"""
361-
Get the derivative of the Tweedie link.
347+
"""Get the derivative of the Tweedie link.
362348
363349
See superclass documentation.
364350
365351
Parameters
366352
----------
367353
mu: array-like
368354
"""
355+
if self.p == 0:
356+
return 1.0 if np.isscalar(mu) else np.ones_like(mu)
357+
if self.p == 1:
358+
return 1 / _asanyarray(mu)
369359
return (1 - self.p) * _asanyarray(mu) ** (-self.p)
370360

371361
@catch_p
372362
def inverse(self, lin_pred):
373-
"""
374-
Get the inverse of the Tweedie link.
363+
"""Get the inverse of the Tweedie link.
375364
376365
See superclass documentation.
377366
378367
Parameters
379368
----------
380369
mu: array-like
381370
"""
371+
if self.p == 0:
372+
return _asanyarray(lin_pred)
373+
if self.p == 1:
374+
return np.exp(_asanyarray(lin_pred))
382375
return _asanyarray(lin_pred) ** (1 / (1 - self.p))
383376

384377
@catch_p
@@ -390,20 +383,30 @@ def inverse_derivative(self, lin_pred):
390383
lin_pred : array-like, shape (n_samples,)
391384
Usually the (fitted) linear predictor.
392385
"""
386+
if self.p == 0:
387+
return 1.0 if np.isscalar(lin_pred) else np.ones_like(lin_pred)
388+
if self.p == 1:
389+
return np.exp(_asanyarray(lin_pred))
393390
return (1 / (1 - self.p)) * _asanyarray(lin_pred) ** (self.p / (1 - self.p))
394391

395392
@catch_p
396393
def inverse_derivative2(self, lin_pred):
397-
"""Compute secondnd derivative of the inverse Tweedie link function \
394+
"""Compute second derivative of the inverse Tweedie link function \
398395
``h''(lin_pred)``.
399396
400397
Parameters
401398
----------
402399
lin_pred : array, shape (n_samples,)
403400
Usually the (fitted) linear predictor.
404401
"""
402+
if self.p == 0:
403+
return 0.0 if np.isscalar(lin_pred) else np.zeros_like(lin_pred)
404+
if self.p == 1:
405+
return np.exp(_asanyarray(lin_pred))
406+
405407
result = _asanyarray(lin_pred) ** ((2 * self.p - 1) / (1 - self.p))
406408
result *= self.p / (1 - self.p) ** 2
409+
407410
return result
408411

409412

@@ -484,7 +487,7 @@ def inverse_derivative(self, lin_pred):
484487
return np.exp(lin_pred - np.exp(lin_pred))
485488

486489
def inverse_derivative2(self, lin_pred):
487-
"""Compute 2nd derivative of the inverse link function ``h''(lin_pred)``.
490+
"""Compute second derivative of the inverse link function ``h''(lin_pred)``.
488491
489492
Parameters
490493
----------

tests/glm/test_link.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_link_properties(link):
3535

3636
def test_equality():
3737
assert TweedieLink(1.5) == TweedieLink(1.5)
38-
assert TweedieLink(1) == LogLink()
38+
assert TweedieLink(1) != LogLink()
3939
assert LogLink() == LogLink()
4040
assert TweedieLink(1.5) != TweedieLink(2.5)
4141
assert TweedieLink(1.5) != LogitLink()

0 commit comments

Comments
 (0)