Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion aeon/classification/distance_based/_elastic_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ class ElasticEnsemble(BaseClassifier):
A ``list`` of strings identifying which distance measures to include. Valid values
are one or more of: ``euclidean``, ``dtw``, ``wdtw``, ``ddtw``, ``wddtw``,
``lcss``, ``erp``, ``msm``, ``twe``. The default value ``all`` means that all
the previously listed distances are used.
the previously listed distances are used. The special value ``ts-quad`` can be
used to select the distance measures for the TS-QUAD ensemble: WDTW, DDTW,
LCSS, and MSM.
proportion_of_param_options : float, default=1
The proportion of the parameter grid space to search optional.
proportion_train_in_param_finding : float, default=1
Expand Down Expand Up @@ -153,6 +155,17 @@ def _fit(self, X, y):
"euclidean",
"twe",
]
elif self.distance_measures == "ts-quad":
self._distance_measures = [
"wdtw",
"ddtw",
"lcss",
"msm",
]
if self.verbose > 0:
print( # noqa: T201
"Configuring ElasticEnsemble as TS-QUAD with WDTW, DDTW, LCSS, MSM."
)
else:
self._distance_measures = self.distance_measures

Expand Down Expand Up @@ -515,6 +528,13 @@ def _get_test_params(cls, parameter_set: str = "default") -> dict | list[dict]:
"majority_vote": True,
"distance_measures": ["dtw", "ddtw", "wdtw"],
}
elif parameter_set == "ts-quad":
return {
"proportion_of_param_options": 0.01,
"proportion_train_for_test": 0.1,
"majority_vote": True,
"distance_measures": "ts-quad",
}
else:
return {
"proportion_of_param_options": 0.01,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,17 @@ def test_all_distance_measures():
ee.fit(X, y)
distances = list(ee.get_metric_params())
assert len(distances) == 9


def test_ts_quad_distance_measures():
"""Test the 'ts-quad' option of the distance_measures parameter."""
X = np.random.random(size=(10, 1, 10))
y = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
ee = ElasticEnsemble(
distance_measures="ts-quad",
proportion_train_in_param_finding=0.2,
proportion_of_param_options=0.1,
)
ee.fit(X, y)
actual_distances = list(ee.get_metric_params())
assert len(actual_distances) == 4
1 change: 1 addition & 0 deletions docs/changelogs/v1.3.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ September 2025

### Enhancements

- [ENH] Implement TS-QUAD as a distance parameter for `ElasticEnsemble` ({pr}`3126`) {user}`Nithurshen`
- [ENH] Improvements to ST transformer and classifier ({pr}`2968`) {user}`MatthewMiddlehurst`
- [ENH] KNN n_jobs and updated kneighbours method ({pr}`2578`) {user}`chrisholder`
- [ENH] Refactor signature code ({pr}`2943`) {user}`TonyBagnall`
Expand Down