Skip to content

Commit a3789bc

Browse files
committed
Add test for ARIMA exogenous variable support
1 parent a67dd7d commit a3789bc

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

aeon/forecasting/stats/tests/test_arima.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,56 @@ def test_autoarima_forecast_is_consistent_with_wrapped():
190190
forecaster = AutoARIMA()
191191
val = forecaster._forecast(y)
192192
assert np.isclose(val, forecaster.final_model_.forecast_)
193+
194+
195+
def test_arima_with_exog_basic_fit_predict():
196+
y_local = np.arange(50, dtype=float)
197+
exog = np.random.RandomState(42).randn(50, 2)
198+
199+
model = ARIMA(p=1, d=0, q=1)
200+
model.fit(y_local, exog=exog)
201+
pred = model.predict(y_local, exog=exog[-1:].copy())
202+
203+
assert isinstance(pred, float)
204+
assert np.isfinite(pred)
205+
206+
207+
def test_arima_exog_shape_mismatch_raises():
208+
y_local = np.arange(20, dtype=float)
209+
exog = np.random.RandomState(0).randn(20, 3)
210+
211+
model = ARIMA(p=1, d=0, q=1)
212+
213+
with pytest.raises(ValueError):
214+
model.fit(y_local, exog=np.random.randn(10, 3))
215+
216+
model.fit(y_local, exog=exog)
217+
218+
with pytest.raises(ValueError):
219+
model.predict(y_local, exog=np.random.randn(1, 5))
220+
221+
222+
def test_arima_iterative_forecast_with_exog():
223+
y_local = np.arange(40, dtype=float)
224+
exog = np.random.RandomState(1).randn(40, 2)
225+
226+
model = ARIMA(p=1, d=1, q=1)
227+
model.fit(y_local, exog=exog)
228+
229+
h = 5
230+
future_exog = np.random.RandomState(2).randn(h, 2)
231+
preds = model.iterative_forecast(y_local, prediction_horizon=h, exog=future_exog)
232+
233+
assert preds.shape == (h,)
234+
assert np.all(np.isfinite(preds))
235+
236+
237+
def test_arima_no_exog_backward_compatibility():
238+
y_local = np.arange(30, dtype=float)
239+
240+
model = ARIMA(p=1, d=1, q=1)
241+
model.fit(y_local)
242+
pred = model.predict(y_local)
243+
244+
assert isinstance(pred, float)
245+
assert np.isfinite(pred)

0 commit comments

Comments
 (0)