Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit e18ea75

Browse files
committedSep 2, 2024·
test: fix linting and some tests [WIP]
1 parent 554a754 commit e18ea75

File tree

11 files changed

+27
-19
lines changed

11 files changed

+27
-19
lines changed
 

‎README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ When the calibration terminates (~half a minute), towards the end of the output
127127
you should see the following messages:
128128
```
129129
True parameters: [0.2, 0.2, 0.75]
130-
Best parameters found: [0.19 0.21 0.68]
130+
Best parameters found: [0.21 0.19 0.76]
131131
```
132132

133133
## Docs

‎black_it/calibrator.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,11 @@ def __init__( # noqa: PLR0913
128128
# initialize arrays
129129
self.params_samp = np.zeros((0, self.param_grid.dims))
130130
self.losses_samp = np.zeros(0)
131-
self.batch_num_samp = np.zeros(0, dtype=int)
132-
self.method_samp = np.zeros(0, dtype=int)
133-
self.series_samp = np.zeros((0, self.ensemble_size, self.N, self.D))
131+
self.batch_num_samp: NDArray[np.int64] = np.zeros(0, dtype=int)
132+
self.method_samp: NDArray[np.int64] = np.zeros(0, dtype=int)
133+
self.series_samp: NDArray[np.float64] = np.zeros(
134+
(0, self.ensemble_size, self.N, self.D),
135+
)
134136

135137
# initialize variables before calibration
136138
self.n_sampled_params = 0

‎black_it/loss_functions/gsl_div.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def get_words(time_series: NDArray[np.float64], length: int) -> NDArray:
270270
"the chosen word length is too high",
271271
exception_class=ValueError,
272272
)
273-
tsw = np.zeros(shape=(tswlen,), dtype=np.int32)
273+
tsw: NDArray[np.float64] = np.zeros(shape=(tswlen,), dtype=np.int32)
274274

275275
for i in range(length):
276276
k = 10 ** (length - i - 1)

‎black_it/loss_functions/likelihood.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from __future__ import annotations
1919

2020
import warnings
21-
from typing import TYPE_CHECKING, Callable
21+
from typing import TYPE_CHECKING, Callable, cast
2222

2323
import numpy as np
2424

@@ -82,9 +82,13 @@ def compute_loss(
8282
Returns:
8383
The loss value.
8484
"""
85-
r = sim_data_ensemble.shape[0] # number of repetitions
86-
s = sim_data_ensemble.shape[1] # simulation length
87-
d = sim_data_ensemble.shape[2] # number of dimensions
85+
sim_data_ensemble_shape: tuple[int, int, int] = cast(
86+
tuple[int, int, int],
87+
sim_data_ensemble.shape,
88+
)
89+
r = sim_data_ensemble_shape[0] # number of repetitions
90+
s = sim_data_ensemble_shape[1] # simulation length
91+
d = sim_data_ensemble_shape[2] # time series dimension
8892

8993
if self.coordinate_weights is not None:
9094
warnings.warn( # noqa: B028

‎black_it/plot/plot_results.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
import os
3434
from collections.abc import Collection
3535

36+
from numpy.typing import NDArray
37+
3638

3739
def _get_samplers_id_table(saving_folder: str | os.PathLike) -> dict[str, int]:
3840
"""Get the id table of the samplers from the checkpoint.
@@ -299,7 +301,7 @@ def plot_sampling_interact(saving_folder: str | os.PathLike) -> None:
299301
data_frame = pd.read_csv(calibration_results_file)
300302

301303
max_bn = int(max(data_frame["batch_num_samp"]))
302-
all_bns = np.arange(max_bn + 1, dtype=int)
304+
all_bns: NDArray[np.int64] = np.arange(max_bn + 1, dtype=int)
303305
indices_bns = np.array_split(all_bns, min(max_bn, 3))
304306

305307
dict_bns = {}

‎black_it/samplers/xgboost.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
if TYPE_CHECKING:
2929
from numpy.typing import NDArray
3030

31-
MAX_FLOAT32 = np.finfo(np.float32).max
32-
MIN_FLOAT32 = np.finfo(np.float32).min
33-
EPS_FLOAT32 = np.finfo(np.float32).eps
31+
MAX_FLOAT32: np.float64 = np.finfo(np.float32).max
32+
MIN_FLOAT32: np.float64 = np.finfo(np.float32).min
33+
EPS_FLOAT32: np.float64 = np.finfo(np.float32).eps
3434

3535

3636
class XGBoostSampler(MLSurrogateSampler):

‎black_it/search_space.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __init__(
7272
self._param_grid: list[NDArray[np.float64]] = []
7373
self._space_size = 1
7474
for i in range(self.dims):
75-
new_col = np.arange(
75+
new_col: NDArray[np.float64] = np.arange(
7676
parameters_bounds[0][i],
7777
parameters_bounds[1][i] + 0.0000001,
7878
parameters_precision[i],
2.47 KB
Binary file not shown.

‎tests/test_examples/test_main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
TRUE_PARAMETERS_STR = "True parameters: [0.2, 0.2, 0.75]"
27-
BEST_PARAMETERS_STR = "Best parameters found: [0.19 0.21 0.68]"
27+
BEST_PARAMETERS_STR = "Best parameters found: [0.21 0.19 0.76]"
2828

2929

3030
class TestMainExample(BaseMainExampleTestClass):

‎tests/test_examples/test_sir_python.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,5 +92,5 @@ def test_sir_w_breaks() -> None:
9292

9393
n = 100
9494
output = SIR_w_breaks(theta, n, seed=model_seed)
95-
95+
9696
assert np.isclose(output, expected_output).all()

‎tests/test_samplers/test_xgboost.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
else:
3535
expected_params = np.array([[0.24, 0.26], [0.37, 0.21], [0.43, 0.14], [0.11, 0.04]])
3636

37-
MAX_FLOAT32 = np.finfo(np.float32).max
38-
MIN_FLOAT32 = np.finfo(np.float32).min
39-
EPS_FLOAT32 = np.finfo(np.float32).eps
37+
MAX_FLOAT32: np.float64 = np.finfo(np.float32).max
38+
MIN_FLOAT32: np.float64 = np.finfo(np.float32).min
39+
EPS_FLOAT32: np.float64 = np.finfo(np.float32).eps
4040

4141

4242
def test_xgboost_2d() -> None:

0 commit comments

Comments
 (0)
Please sign in to comment.