Skip to content

Commit d9807a8

Browse files
committed
extend external predictions benchmarking to multiple repetitions
1 parent c0cbf41 commit d9807a8

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

doubleml/tests/test_sensitivity.py

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,11 @@
11
import pytest
2-
import math
32
import numpy as np
43
import copy
54

65
import doubleml as dml
76
from doubleml import DoubleMLIRM, DoubleMLData
87
from doubleml.datasets import make_irm_data
9-
from sklearn.linear_model import LinearRegression
10-
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
8+
from sklearn.linear_model import LinearRegression, LogisticRegression
119

1210
from ._utils_doubleml_sensitivity_manual import doubleml_sensitivity_manual, \
1311
doubleml_sensitivity_benchmark_manual
@@ -114,18 +112,19 @@ def test_dml_sensitivity_benchmark(dml_sensitivity_multitreat_fixture):
114112
@pytest.fixture(scope="module")
115113
def test_dml_benchmark_fixture(benchmarking_set, n_rep):
116114
random_state = 42
117-
x, y, d = make_irm_data(n_obs=10, dim_x=5, theta=0.5, return_type="np.array")
115+
x, y, d = make_irm_data(n_obs=50, dim_x=5, theta=0, return_type="np.array")
118116

119-
classifier_class = RandomForestClassifier
120-
regressor_class = RandomForestRegressor
117+
classifier_class = LogisticRegression
118+
regressor_class = LinearRegression
121119

122120
np.random.seed(3141)
123121
dml_data = DoubleMLData.from_arrays(x=x, y=y, d=d)
124122
x_list_long = copy.deepcopy(dml_data.x_cols)
125123
dml_int = DoubleMLIRM(dml_data,
126124
ml_m=classifier_class(random_state=random_state),
127-
ml_g=regressor_class(random_state=random_state),
128-
n_folds=2)
125+
ml_g=regressor_class(),
126+
n_folds=2,
127+
n_rep=n_rep)
129128
dml_int.fit(store_predictions=True)
130129
dml_int.sensitivity_analysis()
131130
dml_ext = copy.deepcopy(dml_int)
@@ -136,8 +135,9 @@ def test_dml_benchmark_fixture(benchmarking_set, n_rep):
136135
dml_data_short.x_cols = [x for x in x_list_long if x not in benchmarking_set]
137136
dml_short = DoubleMLIRM(dml_data_short,
138137
ml_m=classifier_class(random_state=random_state),
139-
ml_g=regressor_class(random_state=random_state),
140-
n_folds=2)
138+
ml_g=regressor_class(),
139+
n_folds=2,
140+
n_rep=n_rep)
141141
dml_short.fit(store_predictions=True)
142142
fit_args = {"external_predictions": {"d": {"ml_m": dml_short.predictions["ml_m"][:, :, 0],
143143
"ml_g0": dml_short.predictions["ml_g0"][:, :, 0],
@@ -148,15 +148,15 @@ def test_dml_benchmark_fixture(benchmarking_set, n_rep):
148148
dml_ext.sensitivity_analysis()
149149
df_bm_ext = dml_ext.sensitivity_benchmark(benchmarking_set=benchmarking_set, fit_args=fit_args)
150150

151-
res_dict = {"default_benchmark": df_bm.loc["d", "delta_theta"],
152-
"external_benchmark": df_bm_ext.loc["d", "delta_theta"]}
151+
res_dict = {"default_benchmark": df_bm,
152+
"external_benchmark": df_bm_ext}
153153

154154
return res_dict
155155

156156

157157
@pytest.mark.ci
158158
def test_dml_sensitivity_external_predictions(test_dml_benchmark_fixture):
159-
assert math.isclose(test_dml_benchmark_fixture["default_benchmark"],
160-
test_dml_benchmark_fixture["external_benchmark"],
161-
rel_tol=1e-9,
162-
abs_tol=1e-4)
159+
assert np.allclose(test_dml_benchmark_fixture["default_benchmark"],
160+
test_dml_benchmark_fixture["external_benchmark"],
161+
rtol=1e-9,
162+
atol=1e-4)

0 commit comments

Comments (0)