
Commit 63277c0

sujay-pandit authored
FEAT add WeightedQuadratic datafit to allow sample weights (#258)
Co-authored-by: mathurinm <[email protected]>
Co-authored-by: QB3 <[email protected]>
1 parent ccc6344 commit 63277c0

6 files changed: +177 -4 lines changed

doc/api.rst (+1)

@@ -69,6 +69,7 @@ Datafits
    Quadratic
    QuadraticGroup
    QuadraticSVC
+   WeightedQuadratic


 Solvers

doc/changes/0.4.rst (+1)

@@ -4,6 +4,7 @@ Version 0.4 (in progress)
 -------------------------
 - Add :ref:`GroupLasso Estimator <skglm.GroupLasso>` (PR: :gh:`228`)
 - Add support and tutorial for positive coefficients to :ref:`Group Lasso Penalty <skglm.penalties.WeightedGroupL2>` (PR: :gh:`221`)
+- Add support to weight samples in the quadratic datafit :ref:`Weighted Quadratic Datafit <skglm.datafit.WeightedQuadratic>` (PR: :gh:`258`)


 Version 0.3.1 (2023/12/21)
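
For context, a minimal usage sketch of the new datafit (the data below is hypothetical; the GeneralizedLinearEstimator / L1 / AndersonCD calls mirror the test added later in this commit):

import numpy as np
from skglm import GeneralizedLinearEstimator
from skglm.datafits import WeightedQuadratic
from skglm.penalties import L1
from skglm.solvers import AndersonCD

# hypothetical data: 20 samples, 5 features, one non-negative weight per sample
rng = np.random.RandomState(0)
X, y = rng.randn(20, 5), rng.randn(20)
sample_weights = rng.rand(20)

# the weighted datafit plugs into the generic estimator like any other datafit
datafit = WeightedQuadratic(sample_weights=sample_weights)
model = GeneralizedLinearEstimator(datafit, L1(alpha=0.1), AndersonCD(tol=1e-8))
model.fit(X, y)
print(model.coef_)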

examples/plot_survival_analysis.py (+1 -1)

@@ -53,7 +53,7 @@

 # %%
 # Fitting the Cox Estimator
-# -----------------
+# -------------------------
 #
 # After generating the synthetic data, we can now fit an L1-regularized Cox estimator.
 # To do so, we need to combine a Cox datafit and a :math:`\ell_1` penalty

skglm/datafits/__init__.py (+3 -2)

@@ -1,5 +1,6 @@
 from .base import BaseDatafit, BaseMultitaskDatafit
-from .single_task import Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma, Cox
+from .single_task import (Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma,
+                          Cox, WeightedQuadratic,)
 from .multi_task import QuadraticMultiTask
 from .group import QuadraticGroup, LogisticGroup

@@ -8,5 +9,5 @@
     BaseDatafit, BaseMultitaskDatafit,
     Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma, Cox,
     QuadraticMultiTask,
-    QuadraticGroup, LogisticGroup
+    QuadraticGroup, LogisticGroup, WeightedQuadratic
 ]

skglm/datafits/single_task.py (+120)

@@ -119,6 +119,126 @@ def intercept_update_step(self, y, Xw):
         return np.mean(Xw - y)


+class WeightedQuadratic(BaseDatafit):
+    r"""Weighted Quadratic datafit to handle sample weights.
+
+    The datafit reads:
+
+    .. math:: 1 / (2 xx \sum_(i=1)^(n_"samples") weights_i)
+        \sum_(i=1)^(n_"samples") weights_i (y_i - (Xw)_i)^ 2
+
+    Attributes
+    ----------
+    Xtwy : array, shape (n_features,)
+        Pre-computed quantity used during the gradient evaluation.
+        Equal to ``X.T @ (sample_weights * y)``.
+
+    sample_weights : array, shape (n_samples,)
+        Weights for each sample.
+
+    Note
+    ----
+    The class is jit compiled at fit time using Numba compiler.
+    This allows for faster computations.
+    """
+
+    def __init__(self, sample_weights):
+        self.sample_weights = sample_weights
+
+    def get_spec(self):
+        spec = (
+            ('Xtwy', float64[:]),
+            ('sample_weights', float64[:]),
+        )
+        return spec
+
+    def params_to_dict(self):
+        return {'sample_weights': self.sample_weights}
+
+    def get_lipschitz(self, X, y):
+        n_features = X.shape[1]
+        lipschitz = np.zeros(n_features, dtype=X.dtype)
+        w_sum = self.sample_weights.sum()
+
+        for j in range(n_features):
+            lipschitz[j] = (self.sample_weights * X[:, j] ** 2).sum() / w_sum
+
+        return lipschitz
+
+    def get_lipschitz_sparse(self, X_data, X_indptr, X_indices, y):
+        n_features = len(X_indptr) - 1
+        lipschitz = np.zeros(n_features, dtype=X_data.dtype)
+        w_sum = self.sample_weights.sum()
+
+        for j in range(n_features):
+            nrm2 = 0.
+            for idx in range(X_indptr[j], X_indptr[j + 1]):
+                nrm2 += self.sample_weights[X_indices[idx]] * X_data[idx] ** 2
+
+            lipschitz[j] = nrm2 / w_sum
+
+        return lipschitz
+
+    def initialize(self, X, y):
+        self.Xtwy = X.T @ (self.sample_weights * y)
+
+    def initialize_sparse(self, X_data, X_indptr, X_indices, y):
+        n_features = len(X_indptr) - 1
+        self.Xtwy = np.zeros(n_features, dtype=X_data.dtype)
+
+        for j in range(n_features):
+            xty = 0
+            for idx in range(X_indptr[j], X_indptr[j + 1]):
+                xty += (X_data[idx] * self.sample_weights[X_indices[idx]]
+                        * y[X_indices[idx]])
+            self.Xtwy[j] = xty
+
+    def get_global_lipschitz(self, X, y):
+        w_sum = self.sample_weights.sum()
+        return norm(X.T @ np.sqrt(self.sample_weights), ord=2) ** 2 / w_sum
+
+    def get_global_lipschitz_sparse(self, X_data, X_indptr, X_indices, y):
+        return spectral_norm(
+            X_data * np.sqrt(self.sample_weights[X_indices]),
+            X_indptr, X_indices, len(y)) ** 2 / self.sample_weights.sum()
+
+    def value(self, y, w, Xw):
+        w_sum = self.sample_weights.sum()
+        return np.sum(self.sample_weights * (y - Xw) ** 2) / (2 * w_sum)
+
+    def gradient_scalar(self, X, y, w, Xw, j):
+        return (X[:, j] @ (self.sample_weights * (Xw - y))) / self.sample_weights.sum()
+
+    def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j):
+        XjTXw = 0.
+        for i in range(X_indptr[j], X_indptr[j + 1]):
+            XjTXw += X_data[i] * self.sample_weights[X_indices[i]] * Xw[X_indices[i]]
+        return (XjTXw - self.Xtwy[j]) / self.sample_weights.sum()
+
+    def gradient(self, X, y, Xw):
+        return X.T @ (self.sample_weights * (Xw - y)) / self.sample_weights.sum()
+
+    def raw_grad(self, y, Xw):
+        return (self.sample_weights * (Xw - y)) / self.sample_weights.sum()
+
+    def raw_hessian(self, y, Xw):
+        return self.sample_weights / self.sample_weights.sum()
+
+    def full_grad_sparse(self, X_data, X_indptr, X_indices, y, Xw):
+        n_features = X_indptr.shape[0] - 1
+        grad = np.zeros(n_features, dtype=Xw.dtype)
+
+        for j in range(n_features):
+            XjTXw = 0.
+            for i in range(X_indptr[j], X_indptr[j + 1]):
+                XjTXw += (X_data[i] * self.sample_weights[X_indices[i]]
+                          * Xw[X_indices[i]])
+            grad[j] = (XjTXw - self.Xtwy[j]) / self.sample_weights.sum()
+        return grad
+
+    def intercept_update_step(self, y, Xw):
+        return np.sum(self.sample_weights * (Xw - y)) / self.sample_weights.sum()
+
+
 @njit
 def sigmoid(x):
     """Vectorwise sigmoid."""

skglm/tests/test_datafits.py (+51 -1)

@@ -5,7 +5,8 @@
 from sklearn.linear_model import HuberRegressor
 from numpy.testing import assert_allclose, assert_array_less

-from skglm.datafits import Huber, Logistic, Poisson, Gamma, Cox
+from skglm.datafits import (Huber, Logistic, Poisson, Gamma, Cox, WeightedQuadratic,
+                            Quadratic,)
 from skglm.penalties import L1, WeightedL1
 from skglm.solvers import AndersonCD, ProxNewton
 from skglm import GeneralizedLinearEstimator

@@ -169,5 +170,54 @@ def test_cox(use_efron):
     np.testing.assert_allclose(positive_eig, 0., atol=1e-6)


+@pytest.mark.parametrize("fit_intercept", [True, False])
+def test_sample_weights(fit_intercept):
+    """Test that integer sample weights give the same result as duplicating rows."""
+
+    rng = np.random.RandomState(0)
+
+    n_samples = 20
+    n_features = 100
+    X, y, _ = make_correlated_data(
+        n_samples=n_samples, n_features=n_features, random_state=0)
+
+    indices = rng.choice(n_samples, 3 * n_samples)
+
+    sample_weights = np.zeros(n_samples)
+    for i in indices:
+        sample_weights[i] += 1
+
+    X_overs, y_overs = X[indices], y[indices]
+
+    df_weight = WeightedQuadratic(sample_weights=sample_weights)
+    df_overs = Quadratic()
+
+    # same df value
+    w = np.random.randn(n_features)
+    val_overs = df_overs.value(y_overs, X_overs, X_overs @ w)
+    val_weight = df_weight.value(y, X, X @ w)
+    np.testing.assert_allclose(val_overs, val_weight)
+
+    pen = L1(alpha=1)
+    alpha_max = pen.alpha_max(df_weight.gradient(X, y, np.zeros(X.shape[0])))
+    pen.alpha = alpha_max / 10
+    solver = AndersonCD(tol=1e-12, verbose=10, fit_intercept=fit_intercept)
+
+    model_weight = GeneralizedLinearEstimator(df_weight, pen, solver)
+    model_weight.fit(X, y)
+    print("#" * 80)
+    res = model_weight.coef_
+    model = GeneralizedLinearEstimator(df_overs, pen, solver)
+    model.fit(X_overs, y_overs)
+    res_overs = model.coef_
+
+    np.testing.assert_allclose(res, res_overs)
+    # n_iter = model.n_iter_
+    # n_iter_overs = model.n_iter_
+    # due to numerical errors the assert fails, but (inspecting the verbose output)
+    # everything matches up to numerical precision errors in tol:
+    # np.testing.assert_equal(n_iter, n_iter_overs)
+
+
 if __name__ == '__main__':
     pass
