|
1 | 1 | # Copyright contributors to the Terratorch project
|
2 | 2 |
|
3 |
| -"""Module containing generic dataset classes |
4 |
| -""" |
| 3 | +"""Module containing generic dataset classes""" |
| 4 | + |
5 | 5 | import glob
|
| 6 | +import operator |
6 | 7 | import os
|
7 | 8 | from abc import ABC
|
8 |
| -from functools import partial |
9 |
| -from pathlib import Path |
10 |
| -from typing import Any, List, Union |
11 | 9 | from functools import reduce
|
12 |
| -import operator |
| 10 | +from pathlib import Path |
| 11 | +from typing import Any |
| 12 | + |
13 | 13 | import albumentations as A
|
14 | 14 | import matplotlib as mpl
|
15 | 15 | import numpy as np
|
16 | 16 | import rioxarray
|
17 |
| -import torch |
18 | 17 | import xarray as xr
|
19 |
| -from albumentations.pytorch import ToTensorV2 |
20 | 18 | from einops import rearrange
|
21 |
| -from matplotlib import cm |
22 | 19 | from matplotlib import pyplot as plt
|
23 | 20 | from matplotlib.figure import Figure
|
24 | 21 | from matplotlib.patches import Rectangle
|
@@ -122,15 +119,8 @@ def __init__(
|
122 | 119 | )
|
123 | 120 | self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices
|
124 | 121 |
|
125 |
| - is_bands_by_interval = self._check_if_its_defined_by_interval(dataset_bands, output_bands) |
126 |
| - |
127 |
| - # If the bands are defined by sub-intervals or not. |
128 |
| - if is_bands_by_interval: |
129 |
| - self.dataset_bands = self._generate_bands_intervals(dataset_bands) |
130 |
| - self.output_bands = self._generate_bands_intervals(output_bands) |
131 |
| - else: |
132 |
| - self.dataset_bands = dataset_bands |
133 |
| - self.output_bands = output_bands |
| 122 | + self.dataset_bands = self._generate_bands_intervals(dataset_bands) |
| 123 | + self.output_bands = self._generate_bands_intervals(output_bands) |
134 | 124 |
|
135 | 125 | if self.output_bands and not self.dataset_bands:
|
136 | 126 | msg = "If output bands provided, dataset_bands must also be provided"
|
@@ -183,63 +173,25 @@ def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArr
|
183 | 173 | data = data.fillna(nan_replace)
|
184 | 174 | return data
|
185 | 175 |
|
186 |
| - def _generate_bands_intervals(self, bands_intervals: List[List[int]] = None): |
187 |
| - bands = [] |
188 |
| - for b_interval in bands_intervals: |
189 |
| - bands_sublist = list(range(b_interval[0], b_interval[1] + 1)) |
190 |
| - bands.append(bands_sublist) |
191 |
| - return reduce(operator.iadd, bands, []) |
192 |
| - |
193 |
| - def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type: |
194 |
| - |
195 |
| - band_type = [None, None] |
196 |
| - if not dataset_bands and not output_bands: |
| 176 | + def _generate_bands_intervals(self, bands_intervals: list[int | str | HLSBands | tuple[int]] | None = None): |
| 177 | + if bands_intervals is None: |
197 | 178 | return None
|
198 |
| - else: |
199 |
| - for b, bands_list in enumerate([dataset_bands, output_bands]): |
200 |
| - if all([type(band) == int for band in bands_list]): |
201 |
| - band_type[b] = int |
202 |
| - elif all([type(band) == str for band in bands_list]): |
203 |
| - band_type[b] = str |
204 |
| - else: |
205 |
| - pass |
206 |
| - if band_type.count(band_type[0]) == len(band_type): |
207 |
| - return band_type[0] |
208 |
| - else: |
209 |
| - raise Exception("The bands must be or all str or all int.") |
210 |
| - |
211 |
| - def _check_if_its_defined_by_interval( |
212 |
| - self, dataset_bands: list[int] | list[tuple[int]] = None, output_bands: list[int] | list[tuple[int]] = None |
213 |
| - ) -> bool: |
214 |
| - |
215 |
| - is_dataset_bands_defined = self._bands_defined_by_interval(bands_list=dataset_bands) |
216 |
| - is_output_bands_defined = self._bands_defined_by_interval(bands_list=output_bands) |
217 |
| - |
218 |
| - if is_dataset_bands_defined and is_output_bands_defined: |
219 |
| - return True |
220 |
| - elif not is_dataset_bands_defined and not is_output_bands_defined: |
221 |
| - return False |
222 |
| - else: |
223 |
| - raise Exception( |
224 |
| - f"Both dataset_bands and output_bands must have the same type, but received {dataset_bands} and {output_bands}" |
225 |
| - ) |
226 |
| - |
227 |
| - def _bands_defined_by_interval(self, bands_list: list[int] | list[tuple[int]] = None) -> bool: |
228 |
| - if not bands_list: |
229 |
| - return False |
230 |
| - elif all([type(band) == int or type(band) == str or isinstance(band, HLSBands) for band in bands_list]): |
231 |
| - return False |
232 |
| - elif all([isinstance(subinterval, tuple) for subinterval in bands_list]): |
233 |
| - bands_list_ = [list(subinterval) for subinterval in bands_list] |
234 |
| - if all([type(band) == int for band in sum(bands_list_, [])]): |
235 |
| - return True |
| 179 | + bands = [] |
| 180 | + for element in bands_intervals: |
| 181 | + # if its an interval |
| 182 | + if isinstance(element, tuple): |
| 183 | + if len(element) != 2: # noqa: PLR2004 |
| 184 | + msg = "When defining an interval, a tuple of two integers should be passed, defining start and end indices inclusive" |
| 185 | + raise Exception(msg) |
| 186 | + expanded_element = list(range(element[0], element[1] + 1)) |
| 187 | + bands.extend(expanded_element) |
236 | 188 | else:
|
237 |
| - raise Exception(f"Whe using subintervals, the limits must be int.") |
238 |
| - else: |
239 |
| - raise Exception( |
240 |
| - f"Excpected List[int] or List[str] or List[tuple[int, int]], but received {type(bands_list)}." |
241 |
| - ) |
242 |
| - |
| 189 | + bands.append(element) |
| 190 | + # check the expansion didnt result in duplicate elements |
| 191 | + if len(set(bands)) != len(bands): |
| 192 | + msg = "Duplicate indices detected. Indices must be unique." |
| 193 | + raise Exception(msg) |
| 194 | + return bands |
243 | 195 |
|
244 | 196 | class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset):
|
245 | 197 | """GenericNonGeoSegmentationDataset"""
|
|
0 commit comments