|
| 1 | +"""Tests for loading Keras models in deep autoencoder clusterers.""" |
| 2 | + |
| 3 | +import tempfile |
| 4 | +from pathlib import Path |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +import pytest |
| 8 | + |
| 9 | +from aeon.clustering.deep_learning._ae_abgru import AEAttentionBiGRUClusterer |
| 10 | +from aeon.clustering.deep_learning._ae_bgru import AEBiGRUClusterer |
| 11 | +from aeon.clustering.deep_learning._ae_dcnn import AEDCNNClusterer |
| 12 | +from aeon.clustering.deep_learning._ae_drnn import AEDRNNClusterer |
| 13 | +from aeon.clustering.deep_learning._ae_fcn import AEFCNClusterer |
| 14 | +from aeon.clustering.deep_learning._ae_resnet import AEResNetClusterer |
| 15 | +from aeon.utils.validation._dependencies import _check_soft_dependencies |
| 16 | + |
| 17 | +ALL_DEEP_CLUSTERERS = [ |
| 18 | + AEAttentionBiGRUClusterer, |
| 19 | + AEBiGRUClusterer, |
| 20 | + AEDCNNClusterer, |
| 21 | + AEDRNNClusterer, |
| 22 | + AEFCNClusterer, |
| 23 | + AEResNetClusterer, |
| 24 | +] |
| 25 | + |
| 26 | + |
| 27 | +@pytest.mark.skipif( |
| 28 | + not _check_soft_dependencies("tensorflow", severity="none"), |
| 29 | + reason="TensorFlow not installed.", |
| 30 | +) |
| 31 | +@pytest.mark.parametrize("cls", ALL_DEEP_CLUSTERERS) |
| 32 | +def test_deep_clusterer_load_model(cls): |
| 33 | + """Test that all deep autoencoder clusterers load saved Keras models correctly.""" |
| 34 | + X = np.random.randn(4, 10, 1).astype(np.float32) |
| 35 | + params = cls._get_test_params()[0] |
| 36 | + params["n_epochs"] = 1 |
| 37 | + params["save_best_model"] = True |
| 38 | + |
| 39 | + with tempfile.TemporaryDirectory() as tmp: |
| 40 | + params["file_path"] = tmp + "/" |
| 41 | + model = cls(**params) |
| 42 | + model.fit(X) |
| 43 | + trained_estimator = model._estimator |
| 44 | + saved = list(Path(tmp).glob("*.keras")) |
| 45 | + assert saved, f"No .keras file saved for {cls.__name__}" |
| 46 | + model_path = str(saved[0]) |
| 47 | + loaded = cls(**params) |
| 48 | + loaded.load_model(model_path, trained_estimator) |
| 49 | + assert loaded.model_ is not None, f"Loaded model_ is None for {cls.__name__}" |
| 50 | + assert hasattr(loaded.model_, "predict"), f"Invalid model_ for {cls.__name__}" |
| 51 | + preds = loaded.predict(X) |
| 52 | + assert preds.shape[0] == X.shape[0], f"Predict failed for {cls.__name__}" |
0 commit comments