Skip to content

Commit 680963f

Browse files
committed
add label after transform
Signed-off-by: Carlos Gomes <[email protected]>
1 parent 88b4231 commit 680963f

File tree

5 files changed

+14
-6
lines changed

5 files changed

+14
-6
lines changed

terratorch/datasets/m_bigearthnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,12 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
8282
labels_tensor = torch.tensor(labels_vector, dtype=torch.float)
8383

8484
output = {
85-
"image": image,
86-
"label": labels_tensor
85+
"image": image
8786
}
8887

8988
output = self.transform(**output)
9089

90+
output["label"] = labels_tensor
9191
return output
9292

9393
def _validate_bands(self, bands: Sequence[str]) -> None:

terratorch/datasets/m_brick_kiln.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,12 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
8181
attr_dict = pickle.loads(ast.literal_eval(h5file.attrs["pickle"]))
8282
class_index = attr_dict["label"]
8383

84-
output = {"image": image.astype(np.float32), "label": class_index}
84+
output = {"image": image.astype(np.float32)}
8585

8686
output = self.transform(**output)
8787

88+
output["label"] = class_index
89+
8890
return output
8991

9092
def __len__(self):

terratorch/datasets/m_eurosat.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,12 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
8686
label_class = self.id_to_class[image_id]
8787
label_index = list(self.label_map.keys()).index(label_class)
8888

89-
output = {"image": image.astype(np.float32), "label": label_index}
89+
output = {"image": image.astype(np.float32)}
9090

9191
output = self.transform(**output)
9292

93+
output["label"] = label_index
94+
9395
return output
9496

9597
def _validate_bands(self, bands: Sequence[str]) -> None:

terratorch/datasets/m_forestnet.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,12 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
7474
attr_dict = pickle.loads(ast.literal_eval(h5file.attrs["pickle"]))
7575
class_index = attr_dict["label"]
7676

77-
output = {"image": image.astype(np.float32), "label": class_index}
77+
output = {"image": image.astype(np.float32)}
7878

7979
output = self.transform(**output)
8080

81+
output["label"] = class_index
82+
8183
return output
8284

8385
def _validate_bands(self, bands: Sequence[str]) -> None:

terratorch/datasets/m_pv4ger.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,12 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
6767
attr_dict = pickle.loads(ast.literal_eval(h5file.attrs["pickle"]))
6868
class_index = attr_dict["label"]
6969

70-
output = {"image": image.astype(np.float32), "label": class_index}
70+
output = {"image": image.astype(np.float32)}
7171

7272
output = self.transform(**output)
7373

74+
output["label"] = class_index
75+
7476
return output
7577

7678
def _validate_bands(self, bands: Sequence[str]) -> None:

0 commit comments

Comments
 (0)