Skip to content

Commit f126698

Browse files
committed
test: Add regression task 10-fold CV split integrity test.
1 parent 4b1bdf4 commit f126698

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# License: BSD 3-Clause
2+
from __future__ import annotations
3+
4+
import numpy as np
5+
import pytest
6+
import openml
7+
from openml.testing import TestBase
8+
9+
class OpenMLRegressionTaskSplitTest(TestBase):
    """Regression test for the train/test split indices of OpenML task 2280."""

    __test__ = True

    def setUp(self):
        """Run the base-class set-up, then call ``use_production_server()``."""
        super().setUp()
        self.use_production_server()

    @pytest.mark.production()
    def test_10_fold_cv_splits_integrity(self):
        """Check that the 10-fold CV splits of task 2280 are well formed.

        For every fold the train and test indices must be numpy arrays, must
        not overlap, and must lie in ``[0, n_instances)``. Across all folds the
        test sets must cover every instance exactly once.
        """
        # task 2280; regression; 10-fold cv
        task_id = 2280
        task = openml.tasks.get_task(task_id)

        self.assertEqual(task.task_type_id, openml.tasks.TaskType.SUPERVISED_REGRESSION)

        repeats, folds, _samples = task.get_split_dimensions()
        self.assertEqual(folds, 10, "Task 2280 should have 10 folds")
        self.assertEqual(repeats, 1, "Task 2280 should have 1 repeat")

        # track all test indices to ensure full coverage
        all_test_indices = set()
        # A set silently absorbs duplicates, so also count the raw number of
        # test indices to detect an instance placed in more than one test fold.
        n_test_indices = 0

        X, _ = task.get_X_and_y()
        n_instances = X.shape[0]

        for fold in range(folds):
            train_indices, test_indices = task.get_train_test_split_indices(fold=fold)

            self.assertIsInstance(train_indices, np.ndarray)
            self.assertIsInstance(test_indices, np.ndarray)

            intersection = np.intersect1d(train_indices, test_indices)
            self.assertEqual(
                len(intersection), 0, f"Fold {fold}: Train and test indices overlap"
            )

            self.assertTrue(
                np.all(train_indices < n_instances),
                f"Fold {fold}: Train indices out of bounds",
            )
            self.assertTrue(
                np.all(test_indices < n_instances),
                f"Fold {fold}: Test indices out of bounds",
            )
            self.assertTrue(
                np.all(train_indices >= 0), f"Fold {fold}: Train indices negative"
            )
            self.assertTrue(
                np.all(test_indices >= 0), f"Fold {fold}: Test indices negative"
            )

            all_test_indices.update(test_indices)
            n_test_indices += len(test_indices)

        # assert that the union of all test sets covers the entire dataset
        # specific to cross validation (not holdout)
        self.assertEqual(
            len(all_test_indices),
            n_instances,
            "Union of all test sets should cover the entire dataset",
        )
        expected_indices = set(range(n_instances))
        self.assertEqual(
            all_test_indices,
            expected_indices,
            "Test indices should match all instance indices",
        )

        # With full coverage established above, a total count larger than
        # n_instances can only mean some instance was tested more than once.
        self.assertEqual(
            n_test_indices,
            n_instances,
            "Each instance should appear in exactly one test fold",
        )

0 commit comments

Comments
 (0)