Skip to content

Commit 9a263a8

Browse files
committed
add filename only after transforms
Signed-off-by: Carlos Gomes <[email protected]>
1 parent 3cf13b6 commit 9a263a8

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

terratorch/datasets/generic_pixel_wise_dataset.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,15 @@ def __getitem__(self, index: int) -> dict[str, Any]:
155155
"image": image.astype(np.float32) * self.constant_scale,
156156
"mask": self._load_file(self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[
157157
0
158-
],
159-
"filename": self.image_files[index],
158+
]
160159
}
161160

162161
if self.reduce_zero_label:
163162
output["mask"] -= 1
164163
if self.transform:
165164
output = self.transform(**output)
165+
output["filename"] = self.image_files[index]
166+
166167
return output
167168

168169
def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray:

terratorch/datasets/generic_scalar_label_dataset.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,12 @@ def __getitem__(self, index: int) -> dict[str, Any]:
143143

144144
output = {
145145
"image": image.astype(np.float32) * self.constant_scale,
146-
"label": label,
147-
"filename": self.samples[index][
148-
0
149-
], # samples is an attribute of ImageFolder. Contains a tuple of (Path, Target)
146+
"label": label, # samples is an attribute of ImageFolder. Contains a tuple of (Path, Target)
150147
}
151148
if self.transforms:
152149
output = self.transforms(**output)
150+
output["filename"] = self.image_files[index]
151+
153152
return output
154153

155154
def _load_file(self, path) -> xr.DataArray:

0 commit comments

Comments
 (0)