Skip to content

Commit a09f8e9

Browse files
fix bigearthnet and so2sat (#125)
Signed-off-by: Carlos Gomes <[email protected]>
1 parent 41e2f84 commit a09f8e9

File tree

2 files changed

+34
-36
lines changed

2 files changed

+34
-36
lines changed

terratorch/datasets/m_bigearthnet.py

+31-35
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515

1616
class MBigEarthNonGeo(NonGeoDataset):
17-
1817
all_band_names = (
1918
"COASTAL_AEROSOL",
2019
"BLUE",
@@ -28,77 +27,79 @@ class MBigEarthNonGeo(NonGeoDataset):
2827
"WATER_VAPOR",
2928
"SWIR_1",
3029
"SWIR_2",
31-
"CLOUD_PROBABILITY"
30+
"CLOUD_PROBABILITY",
3231
)
3332

3433
rgb_bands = ("RED", "GREEN", "BLUE")
3534

3635
BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}
3736

38-
def __init__(self, data_root: str, bands: Sequence[str] = BAND_SETS["all"], transform: A.Compose | None = None, split="train", **kwargs: any) -> None:
37+
def __init__(
38+
self,
39+
data_root: str,
40+
bands: Sequence[str] = BAND_SETS["all"],
41+
transform: A.Compose | None = None,
42+
split="train",
43+
partition="default",
44+
) -> None:
3945
super().__init__()
4046
if split not in ["train", "test", "val"]:
4147
msg = "Split must be one of train, test, val."
4248
raise Exception(msg)
4349
if split == "val":
4450
split = "valid"
45-
51+
4652
self.transform = transform if transform else lambda **batch: to_tensor(batch)
4753
self._validate_bands(bands)
4854
self.bands = bands
49-
self.band_indices = np.array(
50-
[self.all_band_names.index(b) for b in bands if b in self.all_band_names]
51-
)
55+
self.band_indices = np.array([self.all_band_names.index(b) for b in bands if b in self.all_band_names])
5256
self.split = split
5357
data_root = Path(data_root)
5458
self.data_directory = data_root / "m-bigearthnet"
55-
59+
5660
label_map_file = self.data_directory / "label_stats.json"
57-
with open(label_map_file, 'r') as file:
61+
with open(label_map_file, "r") as file:
5862
self.label_map = json.load(file)
5963

60-
partition_file = self.data_directory / "default_partition.json"
61-
with open(partition_file, 'r') as file:
64+
partition_file = self.data_directory / f"{partition}_partition.json"
65+
with open(partition_file, "r") as file:
6266
partitions = json.load(file)
6367

6468
if split not in partitions:
6569
raise ValueError(f"Split '{split}' not found.")
6670

6771
self.image_files = [self.data_directory / (filename + ".hdf5") for filename in partitions[split]]
6872

69-
7073
def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
7174
file_path = self.image_files[index]
72-
image_id = file_path.stem
75+
image_id = file_path.stem
7376

74-
with h5py.File(file_path, 'r') as h5file:
77+
with h5py.File(file_path, "r") as h5file:
7578
keys = sorted(h5file.keys())
76-
keys = np.array([key for key in keys if key != 'label'])[self.band_indices]
79+
keys = np.array([key for key in keys if key != "label"])[self.band_indices]
7780
bands = [np.array(h5file[key]) for key in keys]
78-
81+
7982
image = np.stack(bands, axis=-1)
80-
83+
8184
labels_vector = self.label_map[image_id]
8285
labels_tensor = torch.tensor(labels_vector, dtype=torch.float)
8386

84-
output = {
85-
"image": image
86-
}
87+
output = {"image": image}
8788

8889
output = self.transform(**output)
8990

9091
output["label"] = labels_tensor
9192
return output
92-
93+
9394
def _validate_bands(self, bands: Sequence[str]) -> None:
9495
assert isinstance(bands, Sequence), "'bands' must be a sequence"
9596
for band in bands:
9697
if band not in self.all_band_names:
9798
raise ValueError(f"'{band}' is an invalid band name.")
98-
99+
99100
def __len__(self):
100101
return len(self.image_files)
101-
102+
102103
def plot(self, arg, suptitle: str | None = None) -> None:
103104
if isinstance(arg, int):
104105
sample = self.__getitem__(arg)
@@ -120,25 +121,20 @@ def plot(self, arg, suptitle: str | None = None) -> None:
120121
rgb_image = image[rgb_indices, :, :]
121122
rgb_image = np.transpose(rgb_image, (1, 2, 0))
122123
rgb_image = (rgb_image - np.min(rgb_image)) / (np.max(rgb_image) - np.min(rgb_image))
123-
124+
124125
active_labels = [i for i, label in enumerate(labels) if label == 1]
125126

126-
self._plot_sample(
127-
image=rgb_image,
128-
label_indices=active_labels,
129-
suptitle=suptitle
130-
)
131-
127+
self._plot_sample(image=rgb_image, label_indices=active_labels, suptitle=suptitle)
128+
132129
@staticmethod
133130
def _plot_sample(image, label_indices, suptitle=None) -> None:
134131
fig, ax = plt.subplots(figsize=(6, 6))
135132
ax.imshow(image)
136-
ax.axis('off')
133+
ax.axis("off")
137134

138-
title = f'Active Labels: {label_indices}'
135+
title = f"Active Labels: {label_indices}"
139136
if suptitle:
140-
title = f'{suptitle} - {title}'
137+
title = f"{suptitle} - {title}"
141138
ax.set_title(title)
142-
143-
return fig
144139

140+
return fig

terratorch/datasets/m_so2sat.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,12 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
8686
attr_dict = pickle.loads(ast.literal_eval(h5file.attrs["pickle"]))
8787
class_index = attr_dict["label"]
8888

89-
output = {"image": image.astype(np.float32), "label": class_index}
89+
output = {"image": image.astype(np.float32)}
9090

9191
output = self.transform(**output)
9292

93+
output["label"] = class_index
94+
9395
return output
9496

9597
def _validate_bands(self, bands: Sequence[str]) -> None:

0 commit comments

Comments
 (0)