Skip to content

Commit 9c37cd7

Browse files
authored
MNT add tags for sklearn (#293)
1 parent 1799819 commit 9c37cd7

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

.github/workflows/flake8.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ jobs:
3030
- name: Check doc style with pydocstyle
3131
run: |
3232
pip install pydocstyle
33-
pydocstyle skglm --ignore='D100',D102,'D104','D107','D203','D213','D413'
33+
pydocstyle skglm --ignore='D100',D102,'D104','D105','D107','D203','D213','D413',

skglm/estimators.py

+25
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,11 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
448448
warm_start=self.warm_start, verbose=self.verbose)
449449
return solver.path(X, y, datafit, penalty, alphas, coef_init, return_n_iter)
450450

451+
def __sklearn_tags__(self):
452+
tags = super().__sklearn_tags__()
453+
tags.input_tags.sparse = True
454+
return tags
455+
451456

452457
class WeightedLasso(RegressorMixin, LinearModel):
453458
r"""WeightedLasso estimator based on Celer solver and primal extrapolation.
@@ -611,6 +616,11 @@ def fit(self, X, y):
611616
warm_start=self.warm_start, verbose=self.verbose)
612617
return _glm_fit(X, y, self, Quadratic(), penalty, solver)
613618

619+
def __sklearn_tags__(self):
620+
tags = super().__sklearn_tags__()
621+
tags.input_tags.sparse = True
622+
return tags
623+
614624

615625
class ElasticNet(RegressorMixin, LinearModel):
616626
r"""Elastic net estimator.
@@ -765,6 +775,11 @@ def fit(self, X, y):
765775
return _glm_fit(X, y, self, Quadratic(),
766776
L1_plus_L2(self.alpha, self.l1_ratio, self.positive), solver)
767777

778+
def __sklearn_tags__(self):
779+
tags = super().__sklearn_tags__()
780+
tags.input_tags.sparse = True
781+
return tags
782+
768783

769784
class MCPRegression(RegressorMixin, LinearModel):
770785
r"""Linear regression with MCP penalty estimator.
@@ -953,6 +968,11 @@ def fit(self, X, y):
953968
warm_start=self.warm_start, verbose=self.verbose)
954969
return _glm_fit(X, y, self, Quadratic(), penalty, solver)
955970

971+
def __sklearn_tags__(self):
972+
tags = super().__sklearn_tags__()
973+
tags.input_tags.sparse = True
974+
return tags
975+
956976

957977
class SparseLogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstimator):
958978
r"""Sparse Logistic regression estimator.
@@ -1380,6 +1400,11 @@ def fit(self, X, y):
13801400

13811401
return self
13821402

1403+
def __sklearn_tags__(self):
1404+
tags = super().__sklearn_tags__()
1405+
tags.input_tags.sparse = True
1406+
return tags
1407+
13831408

13841409
class MultiTaskLasso(RegressorMixin, LinearModel):
13851410
r"""MultiTaskLasso estimator.

0 commit comments

Comments
 (0)