15
15
# along with this program. If not, see <http://www.gnu.org/licenses/>.
16
16
17
17
"""This module contains tests for the Calibrator.calibrate method."""
18
+ import sys
18
19
from pathlib import Path
19
20
from typing import Any
20
21
from unittest .mock import MagicMock , patch
42
43
class TestCalibrate : # pylint: disable=too-many-instance-attributes,attribute-defined-outside-init
43
44
"""Test the Calibrator.calibrate method."""
44
45
46
+ expected_params = np .array (
47
+ [
48
+ [0.59 , 0.36 ],
49
+ [0.63 , 0.41 ],
50
+ [0.18 , 0.39 ],
51
+ [0.56 , 0.37 ],
52
+ [0.83 , 0.35 ],
53
+ [0.54 , 0.32 ],
54
+ [0.74 , 0.32 ],
55
+ [0.53 , 0.46 ],
56
+ [0.57 , 0.39 ],
57
+ [0.94 , 0.42 ],
58
+ [0.32 , 0.93 ],
59
+ [0.8 , 0.06 ],
60
+ [0.01 , 0.02 ],
61
+ [0.04 , 0.99 ],
62
+ ]
63
+ )
64
+
65
+ expected_losses = [
66
+ 0.33400294 ,
67
+ 0.55274918 ,
68
+ 0.55798021 ,
69
+ 0.61712034 ,
70
+ 0.91962075 ,
71
+ 1.31118518 ,
72
+ 1.51682355 ,
73
+ 1.55503666 ,
74
+ 1.65968375 ,
75
+ 1.78845827 ,
76
+ 1.79905545 ,
77
+ 2.07605975 ,
78
+ 2.28484134 ,
79
+ 3.01432484 ,
80
+ ]
81
+
82
+ win32_expected_params = np .array (
83
+ [
84
+ [0.59 , 0.36 ],
85
+ [0.63 , 0.41 ],
86
+ [0.18 , 0.39 ],
87
+ [0.56 , 0.37 ],
88
+ [0.83 , 0.35 ],
89
+ [0.54 , 0.32 ],
90
+ [0.74 , 0.32 ],
91
+ [0.53 , 0.46 ],
92
+ [0.57 , 0.39 ],
93
+ [0.32 , 0.93 ],
94
+ [0.8 , 0.06 ],
95
+ [0.01 , 0.02 ],
96
+ [1.0 , 0.99 ],
97
+ [0.04 , 0.99 ],
98
+ ]
99
+ )
100
+
101
+ win32_expected_losses = [
102
+ 0.33400294 ,
103
+ 0.55274918 ,
104
+ 0.55798021 ,
105
+ 0.61712034 ,
106
+ 0.91962075 ,
107
+ 1.31118518 ,
108
+ 1.51682355 ,
109
+ 1.55503666 ,
110
+ 1.65968375 ,
111
+ 1.79905545 ,
112
+ 2.07605975 ,
113
+ 2.28484134 ,
114
+ 2.60093616 ,
115
+ 3.01432484 ,
116
+ ]
117
+
45
118
def setup (self ) -> None :
46
119
"""Set up the tests."""
47
120
self .true_params = np .array ([0.50 , 0.50 ])
@@ -75,42 +148,6 @@ def setup(self) -> None:
75
148
@pytest .mark .parametrize ("n_jobs" , [1 , 2 ])
76
149
def test_calibrator_calibrate (self , n_jobs : int ) -> None :
77
150
"""Test the Calibrator.calibrate method, positive case, with different number of jobs."""
78
- expected_params = np .array (
79
- [
80
- [0.59 , 0.36 ],
81
- [0.63 , 0.41 ],
82
- [0.18 , 0.39 ],
83
- [0.56 , 0.37 ],
84
- [0.83 , 0.35 ],
85
- [0.54 , 0.32 ],
86
- [0.74 , 0.32 ],
87
- [0.53 , 0.46 ],
88
- [0.57 , 0.39 ],
89
- [0.94 , 0.42 ],
90
- [0.32 , 0.93 ],
91
- [0.8 , 0.06 ],
92
- [0.01 , 0.02 ],
93
- [0.04 , 0.99 ],
94
- ]
95
- )
96
-
97
- expected_losses = [
98
- 0.33400294 ,
99
- 0.55274918 ,
100
- 0.55798021 ,
101
- 0.61712034 ,
102
- 0.91962075 ,
103
- 1.31118518 ,
104
- 1.51682355 ,
105
- 1.55503666 ,
106
- 1.65968375 ,
107
- 1.78845827 ,
108
- 1.79905545 ,
109
- 2.07605975 ,
110
- 2.28484134 ,
111
- 3.01432484 ,
112
- ]
113
-
114
151
cal = Calibrator (
115
152
samplers = [
116
153
self .random_sampler ,
@@ -134,8 +171,14 @@ def test_calibrator_calibrate(self, n_jobs: int) -> None:
134
171
135
172
params , losses = cal .calibrate (2 )
136
173
137
- assert np .allclose (params , expected_params )
138
- assert np .allclose (losses , expected_losses )
174
+ # TODO: this is a temporary workaround to make tests to run also on Windows. # pylint: disable=fixme
175
+ # See: https://github.com/bancaditalia/black-it/issues/49
176
+ if sys .platform == "win32" :
177
+ assert np .allclose (params , self .win32_expected_params )
178
+ assert np .allclose (losses , self .win32_expected_losses )
179
+ else :
180
+ assert np .allclose (params , self .expected_params )
181
+ assert np .allclose (losses , self .expected_losses )
139
182
140
183
def test_calibrator_with_check_convergence (self , capsys : Any ) -> None :
141
184
"""Test the Calibrator.calibrate method with convergence check."""
0 commit comments