14
14
15
15
16
16
class MBigEarthNonGeo (NonGeoDataset ):
17
-
18
17
all_band_names = (
19
18
"COASTAL_AEROSOL" ,
20
19
"BLUE" ,
@@ -28,77 +27,79 @@ class MBigEarthNonGeo(NonGeoDataset):
28
27
"WATER_VAPOR" ,
29
28
"SWIR_1" ,
30
29
"SWIR_2" ,
31
- "CLOUD_PROBABILITY"
30
+ "CLOUD_PROBABILITY" ,
32
31
)
33
32
34
33
rgb_bands = ("RED" , "GREEN" , "BLUE" )
35
34
36
35
BAND_SETS = {"all" : all_band_names , "rgb" : rgb_bands }
37
36
38
- def __init__ (self , data_root : str , bands : Sequence [str ] = BAND_SETS ["all" ], transform : A .Compose | None = None , split = "train" , ** kwargs : any ) -> None :
37
+ def __init__ (
38
+ self ,
39
+ data_root : str ,
40
+ bands : Sequence [str ] = BAND_SETS ["all" ],
41
+ transform : A .Compose | None = None ,
42
+ split = "train" ,
43
+ partition = "default" ,
44
+ ) -> None :
39
45
super ().__init__ ()
40
46
if split not in ["train" , "test" , "val" ]:
41
47
msg = "Split must be one of train, test, val."
42
48
raise Exception (msg )
43
49
if split == "val" :
44
50
split = "valid"
45
-
51
+
46
52
self .transform = transform if transform else lambda ** batch : to_tensor (batch )
47
53
self ._validate_bands (bands )
48
54
self .bands = bands
49
- self .band_indices = np .array (
50
- [self .all_band_names .index (b ) for b in bands if b in self .all_band_names ]
51
- )
55
+ self .band_indices = np .array ([self .all_band_names .index (b ) for b in bands if b in self .all_band_names ])
52
56
self .split = split
53
57
data_root = Path (data_root )
54
58
self .data_directory = data_root / "m-bigearthnet"
55
-
59
+
56
60
label_map_file = self .data_directory / "label_stats.json"
57
- with open (label_map_file , 'r' ) as file :
61
+ with open (label_map_file , "r" ) as file :
58
62
self .label_map = json .load (file )
59
63
60
- partition_file = self .data_directory / "default_partition .json"
61
- with open (partition_file , 'r' ) as file :
64
+ partition_file = self .data_directory / f" { partition } _partition .json"
65
+ with open (partition_file , "r" ) as file :
62
66
partitions = json .load (file )
63
67
64
68
if split not in partitions :
65
69
raise ValueError (f"Split '{ split } ' not found." )
66
70
67
71
self .image_files = [self .data_directory / (filename + ".hdf5" ) for filename in partitions [split ]]
68
72
69
-
70
73
def __getitem__ (self , index : int ) -> dict [str , torch .Tensor ]:
71
74
file_path = self .image_files [index ]
72
- image_id = file_path .stem
75
+ image_id = file_path .stem
73
76
74
- with h5py .File (file_path , 'r' ) as h5file :
77
+ with h5py .File (file_path , "r" ) as h5file :
75
78
keys = sorted (h5file .keys ())
76
- keys = np .array ([key for key in keys if key != ' label' ])[self .band_indices ]
79
+ keys = np .array ([key for key in keys if key != " label" ])[self .band_indices ]
77
80
bands = [np .array (h5file [key ]) for key in keys ]
78
-
81
+
79
82
image = np .stack (bands , axis = - 1 )
80
-
83
+
81
84
labels_vector = self .label_map [image_id ]
82
85
labels_tensor = torch .tensor (labels_vector , dtype = torch .float )
83
86
84
- output = {
85
- "image" : image
86
- }
87
+ output = {"image" : image }
87
88
88
89
output = self .transform (** output )
89
90
90
91
output ["label" ] = labels_tensor
91
92
return output
92
-
93
+
93
94
def _validate_bands (self , bands : Sequence [str ]) -> None :
94
95
assert isinstance (bands , Sequence ), "'bands' must be a sequence"
95
96
for band in bands :
96
97
if band not in self .all_band_names :
97
98
raise ValueError (f"'{ band } ' is an invalid band name." )
98
-
99
+
99
100
def __len__ (self ):
100
101
return len (self .image_files )
101
-
102
+
102
103
def plot (self , arg , suptitle : str | None = None ) -> None :
103
104
if isinstance (arg , int ):
104
105
sample = self .__getitem__ (arg )
@@ -120,25 +121,20 @@ def plot(self, arg, suptitle: str | None = None) -> None:
120
121
rgb_image = image [rgb_indices , :, :]
121
122
rgb_image = np .transpose (rgb_image , (1 , 2 , 0 ))
122
123
rgb_image = (rgb_image - np .min (rgb_image )) / (np .max (rgb_image ) - np .min (rgb_image ))
123
-
124
+
124
125
active_labels = [i for i , label in enumerate (labels ) if label == 1 ]
125
126
126
- self ._plot_sample (
127
- image = rgb_image ,
128
- label_indices = active_labels ,
129
- suptitle = suptitle
130
- )
131
-
127
+ self ._plot_sample (image = rgb_image , label_indices = active_labels , suptitle = suptitle )
128
+
132
129
@staticmethod
133
130
def _plot_sample (image , label_indices , suptitle = None ) -> None :
134
131
fig , ax = plt .subplots (figsize = (6 , 6 ))
135
132
ax .imshow (image )
136
- ax .axis (' off' )
133
+ ax .axis (" off" )
137
134
138
- title = f' Active Labels: { label_indices } '
135
+ title = f" Active Labels: { label_indices } "
139
136
if suptitle :
140
- title = f' { suptitle } - { title } '
137
+ title = f" { suptitle } - { title } "
141
138
ax .set_title (title )
142
-
143
- return fig
144
139
140
+ return fig
0 commit comments