66
77from aeon .forecasting .utils ._extract_paras import _extract_arma_params
88from aeon .forecasting .utils ._loss_functions import _arima_fit
9- from aeon .forecasting .utils ._nelder_mead import dispatch_loss , nelder_mead
9+ from aeon .forecasting .utils ._nelder_mead import nelder_mead
1010
1111
1212@pytest .mark .parametrize (
@@ -29,63 +29,63 @@ def test_arima_fit(params, data, model, expected_aic):
2929 ), f"AIC mismatch. Got { result } , expected { expected_aic } "
3030
3131
32- @pytest .mark .parametrize (
33- "fn_id, params, data, model, expected_result" ,
34- [
35- (
36- 0 ,
37- np .array ([0.5 , - 0.1 , 0.1 ]),
38- np .array ([1.0 , 2.0 , 1.5 , 1.7 , 2.1 ]),
39- np .array ([1 , 1 , 1 ]),
40- 19.99880 , # example expected result from _arima_fit
41- ),
42- (
43- 1 ,
44- np .array ([0.5 , 0.3 , 1 , 0.4 ]),
45- np .array ([3 , 10 , 12 , 13 , 12 , 10 , 12 , 3 , 10 , 12 , 13 , 12 , 10 , 12 ]),
46- np .array ([1 , 1 , 1 , 4 ]),
47- 55.58355126510806 ,
48- ),
49- (
50- 1 ,
51- np .array ([0.7 , 0.6 , 0.97 , 0.1 ]),
52- np .array ([3 , 10 , 12 , 13 , 12 , 10 , 12 , 3 , 10 , 12 , 13 , 12 , 10 , 12 ]),
53- np .array ([2 , 1 , 1 , 4 ]),
54- 61.797186036891276 ,
55- ),
56- (
57- 1 ,
58- np .array ([0.4 , 0.2 , 0.8 , 0.5 ]),
59- np .array ([3 , 10 , 12 , 13 , 12 , 10 , 12 , 3 , 10 , 12 , 13 , 12 , 10 , 12 ]),
60- np .array ([1 , 2 , 2 , 4 ]),
61- 76.86950158342418 ,
62- ),
63- (
64- 1 ,
65- np .array ([0.7 , 0.5 , 0.85 , 0.2 ]),
66- np .array ([3 , 10 , 12 , 13 , 12 , 10 , 12 , 3 , 10 , 12 , 13 , 12 , 10 , 12 ]),
67- np .array ([2 , 2 , 2 , 4 ]),
68- 82.83246015454237 ,
69- ),
70- (
71- 2 ,
72- np .array ([0.0 ]),
73- np .array ([1.0 , 1.0 , 1.0 , 1.0 ]),
74- np .array ([0 , 0 , 0 ]),
75- ValueError , # expected error for unknown fn_id
76- ),
77- ],
78- )
79- def test_dispatch_loss (fn_id , params , data , model , expected_result ):
80- """Test dispatching loss functions by function ID."""
81- if isinstance (expected_result , type ) and issubclass (expected_result , Exception ):
82- with pytest .raises (expected_result ):
83- dispatch_loss (fn_id , params , data , model )
84- else :
85- result = dispatch_loss (fn_id , params , data , model )
86- assert np .isclose (
87- result , expected_result , atol = 1e-4
88- ), f"Result mismatch. Got { result } , expected { expected_result } "
32+ # @pytest.mark.parametrize(
33+ # "fn_id, params, data, model, expected_result",
34+ # [
35+ # (
36+ # 0,
37+ # np.array([0.5, -0.1, 0.1]),
38+ # np.array([1.0, 2.0, 1.5, 1.7, 2.1]),
39+ # np.array([1, 1, 1]),
40+ # 19.99880, # example expected result from _arima_fit
41+ # ),
42+ # (
43+ # 1,
44+ # np.array([0.5, 0.3, 1, 0.4]),
45+ # np.array([3, 10, 12, 13, 12, 10, 12, 3, 10, 12, 13, 12, 10, 12]),
46+ # np.array([1, 1, 1, 4]),
47+ # 55.58355126510806,
48+ # ),
49+ # (
50+ # 1,
51+ # np.array([0.7, 0.6, 0.97, 0.1]),
52+ # np.array([3, 10, 12, 13, 12, 10, 12, 3, 10, 12, 13, 12, 10, 12]),
53+ # np.array([2, 1, 1, 4]),
54+ # 61.797186036891276,
55+ # ),
56+ # (
57+ # 1,
58+ # np.array([0.4, 0.2, 0.8, 0.5]),
59+ # np.array([3, 10, 12, 13, 12, 10, 12, 3, 10, 12, 13, 12, 10, 12]),
60+ # np.array([1, 2, 2, 4]),
61+ # 76.86950158342418,
62+ # ),
63+ # (
64+ # 1,
65+ # np.array([0.7, 0.5, 0.85, 0.2]),
66+ # np.array([3, 10, 12, 13, 12, 10, 12, 3, 10, 12, 13, 12, 10, 12]),
67+ # np.array([2, 2, 2, 4]),
68+ # 82.83246015454237,
69+ # ),
70+ # (
71+ # 2,
72+ # np.array([0.0]),
73+ # np.array([1.0, 1.0, 1.0, 1.0]),
74+ # np.array([0, 0, 0]),
75+ # ValueError, # expected error for unknown fn_id
76+ # ),
77+ # ],
78+ # )
79+ # def test_dispatch_loss(fn_id, params, data, model, expected_result):
80+ # """Test dispatching loss functions by function ID."""
81+ # if isinstance(expected_result, type) and issubclass(expected_result, Exception):
82+ # with pytest.raises(expected_result):
83+ # dispatch_loss(fn_id, params, data, model)
84+ # else:
85+ # result = dispatch_loss(fn_id, params, data, model)
86+ # assert np.isclose(
87+ # result, expected_result, atol=1e-4
88+ # ), f"Result mismatch. Got {result}, expected {expected_result}"
8989
9090
9191@pytest .mark .parametrize (
0 commit comments