Skip to content

Commit 90d3e49

Browse files
Merge pull request #65 from IBM/pin/albumentations
pin albumentations
2 parents 9df5f5f + 9a263a8 commit 90d3e49

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

pyproject.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ dependencies = [
3939
"geobench>=1.0.0",
4040
"mlflow>=2.12.1",
4141
# broken due to https://github.com/Lightning-AI/pytorch-lightning/issues/19977
42-
"lightning>=2, <=2.2.5"
42+
"lightning>=2, <=2.2.5",
43+
# see issue #64
44+
"albumentations<=1.4.10"
4345
]
4446

4547
[project.optional-dependencies]

terratorch/datasets/generic_pixel_wise_dataset.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,15 @@ def __getitem__(self, index: int) -> dict[str, Any]:
155155
"image": image.astype(np.float32) * self.constant_scale,
156156
"mask": self._load_file(self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[
157157
0
158-
],
159-
"filename": self.image_files[index],
158+
]
160159
}
161160

162161
if self.reduce_zero_label:
163162
output["mask"] -= 1
164163
if self.transform:
165164
output = self.transform(**output)
165+
output["filename"] = self.image_files[index]
166+
166167
return output
167168

168169
def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray:

terratorch/datasets/generic_scalar_label_dataset.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,12 @@ def __getitem__(self, index: int) -> dict[str, Any]:
143143

144144
output = {
145145
"image": image.astype(np.float32) * self.constant_scale,
146-
"label": label,
147-
"filename": self.samples[index][
148-
0
149-
], # samples is an attribute of ImageFolder. Contains a tuple of (Path, Target)
146+
"label": label, # samples is an attribute of ImageFolder. Contains a tuple of (Path, Target)
150147
}
151148
if self.transforms:
152149
output = self.transforms(**output)
150+
output["filename"] = self.image_files[index]
151+
153152
return output
154153

155154
def _load_file(self, path) -> xr.DataArray:

0 commit comments

Comments
 (0)