|
1 | 1 | from terratorch.datasets import HLSBands
|
2 | 2 |
|
3 |
| -def _are_sublists_of_int(item) -> bool: |
| 3 | +def _are_sublists_of_int(item) -> (bool, bool): |
4 | 4 |
|
5 | 5 | if all([isinstance(i, list) for i in item]):
|
6 | 6 | if all([isinstance(i, int) for i in sum(item, [])]):
|
7 |
| - return True |
| 7 | + return True, True |
8 | 8 | else:
|
9 |
| - return False |
| 9 | + raise Exception(f"It's expected sublists be [int, int], but rceived {model_bands}") |
| 10 | + elif len(item) == 2 and type(item[0]) == type(item[1]) == int: |
| 11 | + return False, True |
10 | 12 | else:
|
11 |
| - return False |
| 13 | + return False, False |
12 | 14 |
|
13 | 15 | def _estimate_in_chans(model_bands: list[HLSBands] | list[str] | tuple[int, int] = None) -> int:
|
14 | 16 |
|
15 | 17 | # Conditional to deal with the different possible choices for the bands
|
16 | 18 | # Bands as lists of strings or enum
|
17 |
| - if all([isinstance(b, str) or isinstance(b, HLSBands) for b in model_bands]): |
18 |
| - in_chans = len(model_bands) |
| 19 | + is_sublist, requires_special_eval = _are_sublists_of_int(model_bands) |
| 20 | + |
19 | 21 | # Bands as intervals limited by integers
|
20 |
| - elif all([isinstance(b, int) for b in model_bands] or _are_sublists_of_int(model_bands)): |
| 22 | + if requires_special_eval: |
21 | 23 |
|
22 |
| - if _are_sublists_of_int(model_bands): |
| 24 | + if is_sublist: |
23 | 25 | in_chans = sum([i[-1] - i[0] for i in model_bands])
|
24 | 26 | else:
|
25 | 27 | in_chans = model_bands[-1] - model_bands[0]
|
26 | 28 | else:
|
27 |
| - raise Exception(f"Expected bands to be list(str) or [int, int] but received {model_bands}") |
28 |
| - |
| 29 | + in_chans = len(model_bands) |
| 30 | + |
29 | 31 | return in_chans
|
30 | 32 |
|
31 | 33 |
|
0 commit comments