Skip to content

Commit a54d7e2

Browse files
authored
[MNT] Add missing load_model test for deep clusterers (Fixes #3080) (#3111)
* Add tests for loading/saving deep clusterers * Automatic `pre-commit` fixes * [MNT] Add comprehensive load/save tests for all deep clusterers * Improve deep clusterer load/save tests --------- Co-authored-by: satwiksps <[email protected]>
1 parent dc0c357 commit a54d7e2

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""Tests for loading Keras models in deep autoencoder clusterers."""
2+
3+
import tempfile
4+
from pathlib import Path
5+
6+
import numpy as np
7+
import pytest
8+
9+
from aeon.clustering.deep_learning._ae_abgru import AEAttentionBiGRUClusterer
10+
from aeon.clustering.deep_learning._ae_bgru import AEBiGRUClusterer
11+
from aeon.clustering.deep_learning._ae_dcnn import AEDCNNClusterer
12+
from aeon.clustering.deep_learning._ae_drnn import AEDRNNClusterer
13+
from aeon.clustering.deep_learning._ae_fcn import AEFCNClusterer
14+
from aeon.clustering.deep_learning._ae_resnet import AEResNetClusterer
15+
from aeon.utils.validation._dependencies import _check_soft_dependencies
16+
17+
ALL_DEEP_CLUSTERERS = [
18+
AEAttentionBiGRUClusterer,
19+
AEBiGRUClusterer,
20+
AEDCNNClusterer,
21+
AEDRNNClusterer,
22+
AEFCNClusterer,
23+
AEResNetClusterer,
24+
]
25+
26+
27+
@pytest.mark.skipif(
28+
not _check_soft_dependencies("tensorflow", severity="none"),
29+
reason="TensorFlow not installed.",
30+
)
31+
@pytest.mark.parametrize("cls", ALL_DEEP_CLUSTERERS)
32+
def test_deep_clusterer_load_model(cls):
33+
"""Test that all deep autoencoder clusterers load saved Keras models correctly."""
34+
X = np.random.randn(4, 10, 1).astype(np.float32)
35+
params = cls._get_test_params()[0]
36+
params["n_epochs"] = 1
37+
params["save_best_model"] = True
38+
39+
with tempfile.TemporaryDirectory() as tmp:
40+
params["file_path"] = tmp + "/"
41+
model = cls(**params)
42+
model.fit(X)
43+
trained_estimator = model._estimator
44+
saved = list(Path(tmp).glob("*.keras"))
45+
assert saved, f"No .keras file saved for {cls.__name__}"
46+
model_path = str(saved[0])
47+
loaded = cls(**params)
48+
loaded.load_model(model_path, trained_estimator)
49+
assert loaded.model_ is not None, f"Loaded model_ is None for {cls.__name__}"
50+
assert hasattr(loaded.model_, "predict"), f"Invalid model_ for {cls.__name__}"
51+
preds = loaded.predict(X)
52+
assert preds.shape[0] == X.shape[0], f"Predict failed for {cls.__name__}"

0 commit comments

Comments
 (0)