1
1
import pytest
2
- import math
3
2
import numpy as np
4
3
import copy
5
4
6
5
import doubleml as dml
7
6
from doubleml import DoubleMLIRM , DoubleMLData
8
7
from doubleml .datasets import make_irm_data
9
- from sklearn .linear_model import LinearRegression
10
- from sklearn .ensemble import RandomForestRegressor , RandomForestClassifier
8
+ from sklearn .linear_model import LinearRegression , LogisticRegression
11
9
12
10
from ._utils_doubleml_sensitivity_manual import doubleml_sensitivity_manual , \
13
11
doubleml_sensitivity_benchmark_manual
@@ -114,18 +112,19 @@ def test_dml_sensitivity_benchmark(dml_sensitivity_multitreat_fixture):
114
112
@pytest .fixture (scope = "module" )
115
113
def test_dml_benchmark_fixture (benchmarking_set , n_rep ):
116
114
random_state = 42
117
- x , y , d = make_irm_data (n_obs = 10 , dim_x = 5 , theta = 0.5 , return_type = "np.array" )
115
+ x , y , d = make_irm_data (n_obs = 50 , dim_x = 5 , theta = 0 , return_type = "np.array" )
118
116
119
- classifier_class = RandomForestClassifier
120
- regressor_class = RandomForestRegressor
117
+ classifier_class = LogisticRegression
118
+ regressor_class = LinearRegression
121
119
122
120
np .random .seed (3141 )
123
121
dml_data = DoubleMLData .from_arrays (x = x , y = y , d = d )
124
122
x_list_long = copy .deepcopy (dml_data .x_cols )
125
123
dml_int = DoubleMLIRM (dml_data ,
126
124
ml_m = classifier_class (random_state = random_state ),
127
- ml_g = regressor_class (random_state = random_state ),
128
- n_folds = 2 )
125
+ ml_g = regressor_class (),
126
+ n_folds = 2 ,
127
+ n_rep = n_rep )
129
128
dml_int .fit (store_predictions = True )
130
129
dml_int .sensitivity_analysis ()
131
130
dml_ext = copy .deepcopy (dml_int )
@@ -136,8 +135,9 @@ def test_dml_benchmark_fixture(benchmarking_set, n_rep):
136
135
dml_data_short .x_cols = [x for x in x_list_long if x not in benchmarking_set ]
137
136
dml_short = DoubleMLIRM (dml_data_short ,
138
137
ml_m = classifier_class (random_state = random_state ),
139
- ml_g = regressor_class (random_state = random_state ),
140
- n_folds = 2 )
138
+ ml_g = regressor_class (),
139
+ n_folds = 2 ,
140
+ n_rep = n_rep )
141
141
dml_short .fit (store_predictions = True )
142
142
fit_args = {"external_predictions" : {"d" : {"ml_m" : dml_short .predictions ["ml_m" ][:, :, 0 ],
143
143
"ml_g0" : dml_short .predictions ["ml_g0" ][:, :, 0 ],
@@ -148,15 +148,15 @@ def test_dml_benchmark_fixture(benchmarking_set, n_rep):
148
148
dml_ext .sensitivity_analysis ()
149
149
df_bm_ext = dml_ext .sensitivity_benchmark (benchmarking_set = benchmarking_set , fit_args = fit_args )
150
150
151
- res_dict = {"default_benchmark" : df_bm . loc [ "d" , "delta_theta" ] ,
152
- "external_benchmark" : df_bm_ext . loc [ "d" , "delta_theta" ] }
151
+ res_dict = {"default_benchmark" : df_bm ,
152
+ "external_benchmark" : df_bm_ext }
153
153
154
154
return res_dict
155
155
156
156
157
157
@pytest .mark .ci
158
158
def test_dml_sensitivity_external_predictions (test_dml_benchmark_fixture ):
159
- assert math . isclose (test_dml_benchmark_fixture ["default_benchmark" ],
160
- test_dml_benchmark_fixture ["external_benchmark" ],
161
- rel_tol = 1e-9 ,
162
- abs_tol = 1e-4 )
159
+ assert np . allclose (test_dml_benchmark_fixture ["default_benchmark" ],
160
+ test_dml_benchmark_fixture ["external_benchmark" ],
161
+ rtol = 1e-9 ,
162
+ atol = 1e-4 )
0 commit comments