Skip to content

Commit 5708e0a

Browse files
More tests to check if the bands ar properly returned
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent aedca32 commit 5708e0a

File tree

1 file changed

+101
-2
lines changed

1 file changed

+101
-2
lines changed

tests/test_generic_dataset.py

+101-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
SEGMENTATION_LABEL_PATH = "tests/segmentation_test_label.tif"
1515
NUM_CLASSES_SEGMENTATION = 2
1616

17-
1817
@pytest.fixture(scope="session")
1918
def split_file_path(tmp_path_factory):
2019
split_file_path = tmp_path_factory.mktemp("split") / "split.txt"
@@ -59,7 +58,6 @@ def test_data_type_regression_float_float(self, regression_dataset):
5958
assert torch.is_floating_point(regression_dataset[0]["image"])
6059
assert torch.is_floating_point(regression_dataset[0]["mask"])
6160

62-
6361
class TestGenericSegmentationDataset:
6462
@pytest.fixture(scope="class")
6563
def data_root_segmentation(self, tmp_path_factory: TempPathFactory):
@@ -94,3 +92,104 @@ def test_file_discovery_generic_segmentation_dataset(self, segmentation_dataset)
9492
def test_data_type_regression_float_long(self, segmentation_dataset):
9593
assert torch.is_floating_point(segmentation_dataset[0]["image"])
9694
assert not torch.is_floating_point(segmentation_dataset[0]["mask"])
95+
96+
# Testing bands
97+
# HLS_bands
98+
HLS_dataset_bands = [
99+
"COASTAL_AEROSOL",
100+
"BLUE",
101+
"GREEN",
102+
"RED",
103+
"NIR_NARROW",
104+
"SWIR_1",
105+
"SWIR_2",
106+
"CIRRUS",
107+
"THEMRAL_INFRARED_1",
108+
"THEMRAL_INFRARED_2",
109+
]
110+
111+
HLS_output_bands = [
112+
"BLUE",
113+
"GREEN",
114+
"RED",
115+
"NIR_NARROW",
116+
"SWIR_1",
117+
"SWIR_2",
118+
]
119+
120+
# Integer Intervals bands
121+
int_dataset_bands = (0,10)
122+
int_output_bands = (1,6)
123+
# Simple string bands
124+
str_dataset_bands = [f"band_{j}" for j in range(10)]
125+
str_output_bands = [f"band_{j}" for j in range(1,6)]
126+
127+
128+
class TestGenericDatasetWithBands:
129+
@pytest.fixture(scope="class")
130+
def data_root_regression(self, tmp_path_factory: TempPathFactory):
131+
data_dir = tmp_path_factory.mktemp("data")
132+
image_dir_path = data_dir / "input_data"
133+
label_dir_path = data_dir / "label_data"
134+
os.mkdir(image_dir_path)
135+
os.mkdir(label_dir_path)
136+
for i in range(10):
137+
os.symlink(REGRESSION_IMAGE_PATH, image_dir_path / f"{i}_img.tif")
138+
os.symlink(REGRESSION_LABEL_PATH, label_dir_path / f"{i}_label.tif")
139+
140+
# add a few with no suffix
141+
for i in range(10, 15):
142+
os.symlink(REGRESSION_IMAGE_PATH, image_dir_path / f"{i}.tif")
143+
os.symlink(REGRESSION_LABEL_PATH, label_dir_path / f"{i}.tif")
144+
return data_dir
145+
146+
@pytest.fixture(scope="class")
147+
def regression_dataset_with_HLS_bands(self, data_root_regression, split_file_path):
148+
return GenericNonGeoPixelwiseRegressionDataset(
149+
data_root_regression,
150+
dataset_bands=HLS_dataset_bands,
151+
output_bands=HLS_output_bands,
152+
image_grep="input_data/*_img.tif",
153+
label_grep="label_data/*_label.tif",
154+
split=split_file_path,
155+
)
156+
157+
@pytest.fixture(scope="class")
158+
def regression_dataset_with_interval_bands(self, data_root_regression, split_file_path):
159+
return GenericNonGeoPixelwiseRegressionDataset(
160+
data_root_regression,
161+
dataset_bands=[int_dataset_bands],
162+
output_bands=[int_output_bands],
163+
image_grep="input_data/*_img.tif",
164+
label_grep="label_data/*_label.tif",
165+
split=split_file_path,
166+
)
167+
168+
@pytest.fixture(scope="class")
169+
def regression_dataset_with_str_bands(self, data_root_regression, split_file_path):
170+
return GenericNonGeoPixelwiseRegressionDataset(
171+
data_root_regression,
172+
dataset_bands=str_dataset_bands,
173+
output_bands=str_output_bands,
174+
image_grep="input_data/*_img.tif",
175+
label_grep="label_data/*_label.tif",
176+
split=split_file_path,
177+
)
178+
179+
def test_usage_of_HLS_bands(self, regression_dataset_with_HLS_bands):
180+
181+
dataset = regression_dataset_with_HLS_bands
182+
assert dataset.output_bands == HLS_output_bands
183+
184+
def test_usage_of_interval_bands(self, regression_dataset_with_interval_bands):
185+
186+
dataset = regression_dataset_with_interval_bands
187+
int_output_bands_ = list(int_output_bands)
188+
int_output_bands_[1] += 1
189+
assert dataset.output_bands == list(range(*int_output_bands_))
190+
191+
def test_usage_of_str_bands(self, regression_dataset_with_str_bands):
192+
193+
dataset = regression_dataset_with_str_bands
194+
assert dataset.output_bands == str_output_bands
195+

0 commit comments

Comments
 (0)