@@ -88,6 +88,7 @@ def __init__(
88
88
expected 0. Defaults to False.
89
89
"""
90
90
super ().__init__ ()
91
+
91
92
self .split_file = split
92
93
93
94
label_data_root = label_data_root if label_data_root is not None else data_root
@@ -136,7 +137,7 @@ def __init__(
136
137
if bands_type == str :
137
138
raise UserWarning ("When the bands are defined as str, guarantee your input files" +
138
139
"are organized by band and all have its specific name." )
139
-
140
+
140
141
if self .output_bands and not self .dataset_bands :
141
142
msg = "If output bands provided, dataset_bands must also be provided"
142
143
return Exception (msg ) # noqa: PLE0101
@@ -146,7 +147,9 @@ def __init__(
146
147
if len (set (self .output_bands ) & set (self .dataset_bands )) != len (self .output_bands ):
147
148
msg = "Output bands must be a subset of dataset bands"
148
149
raise Exception (msg )
150
+
149
151
self .filter_indices = [self .dataset_bands .index (band ) for band in self .output_bands ]
152
+
150
153
else :
151
154
self .filter_indices = None
152
155
@@ -176,7 +179,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
176
179
if self .transform :
177
180
output = self .transform (** output )
178
181
return output
179
-
182
+
180
183
def _load_file (self , path , nan_replace : int | float | None = None ) -> xr .DataArray :
181
184
data = rioxarray .open_rasterio (path , masked = True )
182
185
if nan_replace is not None :
@@ -200,7 +203,7 @@ def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type:
200
203
band_type [b ] = str
201
204
else :
202
205
pass
203
- if band_type .cound (band_type [0 ]) == len (band_type )
206
+ if band_type .cound (band_type [0 ]) == len (band_type ):
204
207
return band_type [0 ]
205
208
else :
206
209
raise Exception ("The bands must be or all str or all int." )
@@ -232,8 +235,8 @@ def __init__(
232
235
ignore_split_file_extensions : bool = True ,
233
236
allow_substring_split_file : bool = True ,
234
237
rgb_indices : list [str ] | None = None ,
235
- dataset_bands : list [HLSBands | int | list [int ]] | None = None ,
236
- output_bands : list [HLSBands | int | list [int ]] | None = None ,
238
+ dataset_bands : list [HLSBands | int | list [int ] | str ] | None = None ,
239
+ output_bands : list [HLSBands | int | list [int ] | str ] | None = None ,
237
240
class_names : list [str ] | None = None ,
238
241
constant_scale : float = 1 ,
239
242
transform : A .Compose | None = None ,
@@ -399,8 +402,8 @@ def __init__(
399
402
ignore_split_file_extensions : bool = True ,
400
403
allow_substring_split_file : bool = True ,
401
404
rgb_indices : list [int ] | None = None ,
402
- dataset_bands : list [HLSBands | int | list [int ]] | None = None ,
403
- output_bands : list [HLSBands | int | list [int ]] | None = None ,
405
+ dataset_bands : list [HLSBands | int | list [int ] | str ] | None = None ,
406
+ output_bands : list [HLSBands | int | list [int ] | str ] | None = None ,
404
407
constant_scale : float = 1 ,
405
408
transform : A .Compose | None = None ,
406
409
no_data_replace : float | None = None ,
0 commit comments