Skip to content

Commit 3d9b5f5

Browse files
Better conditional for all the kinds of bands
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 52f0c12 commit 3d9b5f5

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

terratorch/models/backbones/utils.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,33 @@
11
from terratorch.datasets import HLSBands
22

3-
def _are_sublists_of_int(item) -> bool:
3+
def _are_sublists_of_int(item) -> (bool, bool):
44

55
if all([isinstance(i, list) for i in item]):
66
if all([isinstance(i, int) for i in sum(item, [])]):
7-
return True
7+
return True, True
88
else:
9-
return False
9+
raise Exception(f"It's expected sublists be [int, int], but rceived {model_bands}")
10+
elif len(item) == 2 and type(item[0]) == type(item[1]) == int:
11+
return False, True
1012
else:
11-
return False
13+
return False, False
1214

1315
def _estimate_in_chans(model_bands: list[HLSBands] | list[str] | tuple[int, int] = None) -> int:
1416

1517
# Conditional to deal with the different possible choices for the bands
1618
# Bands as lists of strings or enum
17-
if all([isinstance(b, str) or isinstance(b, HLSBands) for b in model_bands]):
18-
in_chans = len(model_bands)
19+
is_sublist, requires_special_eval = _are_sublists_of_int(model_bands)
20+
1921
# Bands as intervals limited by integers
20-
elif all([isinstance(b, int) for b in model_bands] or _are_sublists_of_int(model_bands)):
22+
if requires_special_eval:
2123

22-
if _are_sublists_of_int(model_bands):
24+
if is_sublist:
2325
in_chans = sum([i[-1] - i[0] for i in model_bands])
2426
else:
2527
in_chans = model_bands[-1] - model_bands[0]
2628
else:
27-
raise Exception(f"Expected bands to be list(str) or [int, int] but received {model_bands}")
28-
29+
in_chans = len(model_bands)
30+
2931
return in_chans
3032

3133

0 commit comments

Comments
 (0)