@@ -174,6 +174,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
174
174
"mask" : self ._load_file (self .segmentation_mask_files [index ], nan_replace = self .no_label_replace ).to_numpy ()[0 ],
175
175
"filename" : self .image_files [index ],
176
176
}
177
+
177
178
if self .reduce_zero_label :
178
179
output ["mask" ] -= 1
179
180
if self .transform :
@@ -196,17 +197,20 @@ def _generate_bands_intervals(self, bands_intervals:List[List[int]] = None):
196
197
def _bands_as_int_or_str (self , dataset_bands , output_bands ) -> type :
197
198
198
199
band_type = [None , None ]
199
- for b , bands_list in enumerate ([dataset_bands , output_bands ]):
200
- if all ([type (band )== int for band in bands_list ]):
201
- band_type [b ] = int
202
- elif all ([type (band )== str for band in bands_list ]):
203
- band_type [b ] = str
204
- else :
205
- pass
206
- if band_type .count (band_type [0 ]) == len (band_type ):
207
- return band_type [0 ]
200
+ if not dataset_bands and not output_bands :
201
+ return None
208
202
else :
209
- raise Exception ("The bands must be or all str or all int." )
203
+ for b , bands_list in enumerate ([dataset_bands , output_bands ]):
204
+ if all ([type (band )== int for band in bands_list ]):
205
+ band_type [b ] = int
206
+ elif all ([type (band )== str for band in bands_list ]):
207
+ band_type [b ] = str
208
+ else :
209
+ pass
210
+ if band_type .count (band_type [0 ]) == len (band_type ):
211
+ return band_type [0 ]
212
+ else :
213
+ raise Exception ("The bands must be or all str or all int." )
210
214
211
215
def _bands_defined_by_interval (self , bands_list : list [int ] | list [list [int ]] = None ) -> bool :
212
216
if not bands_list :
0 commit comments