9
9
from terratorch .datamodules .utils import wrap_in_compose_is_list
10
10
11
11
MEANS = {
12
- "COASTAL_AEROSOL" : 378.4027 ,
13
- "BLUE" : 482.2730 ,
14
- "GREEN" : 706.5345 ,
15
- "RED" : 720.9285 ,
16
- "RED_EDGE_1" : 1100.6688 ,
12
+ "COASTAL_AEROSOL" : 378.4027 ,
13
+ "BLUE" : 482.2730 ,
14
+ "GREEN" : 706.5345 ,
15
+ "RED" : 720.9285 ,
16
+ "RED_EDGE_1" : 1100.6688 ,
17
17
"RED_EDGE_2" : 1909.2914 ,
18
- "RED_EDGE_3" : 2191.6985 ,
19
- "NIR_BROAD" : 2336.8706 ,
20
- "NIR_NARROW" : 2394.7449 ,
21
- "WATER_VAPOR" : 2368.3127 ,
22
- "SWIR_1" : 1875.2487 ,
23
- "SWIR_2" : 1229.3818
18
+ "RED_EDGE_3" : 2191.6985 ,
19
+ "NIR_BROAD" : 2336.8706 ,
20
+ "NIR_NARROW" : 2394.7449 ,
21
+ "WATER_VAPOR" : 2368.3127 ,
22
+ "SWIR_1" : 1875.2487 ,
23
+ "SWIR_2" : 1229.3818 ,
24
24
}
25
25
26
- STDS = {
27
- "COASTAL_AEROSOL" : 157.5666 ,
28
- "BLUE" : 255.0429 ,
29
- "GREEN" : 303.1750 ,
30
- "RED" : 391.2943 ,
31
- "RED_EDGE_1" : 380.7916 ,
32
- "RED_EDGE_2" : 551.6558 ,
26
+ STDS = {
27
+ "COASTAL_AEROSOL" : 157.5666 ,
28
+ "BLUE" : 255.0429 ,
29
+ "GREEN" : 303.1750 ,
30
+ "RED" : 391.2943 ,
31
+ "RED_EDGE_1" : 380.7916 ,
32
+ "RED_EDGE_2" : 551.6558 ,
33
33
"RED_EDGE_3" : 638.8196 ,
34
34
"NIR_BROAD" : 744.2009 ,
35
- "NIR_NARROW" : 675.4041 ,
36
- "WATER_VAPOR" : 561.0154 ,
37
- "SWIR_1" : 563.4095 ,
38
- "SWIR_2" : 479.1786
35
+ "NIR_NARROW" : 675.4041 ,
36
+ "WATER_VAPOR" : 561.0154 ,
37
+ "SWIR_1" : 563.4095 ,
38
+ "SWIR_2" : 479.1786 ,
39
39
}
40
40
41
+
41
42
class MBigEarthNonGeoDataModule (NonGeoDataModule ):
42
43
def __init__ (
43
- self ,
44
- batch_size : int = 8 ,
45
- num_workers : int = 0 ,
44
+ self ,
45
+ batch_size : int = 8 ,
46
+ num_workers : int = 0 ,
46
47
data_root : str = "./" ,
47
48
train_transform : A .Compose | None | list [A .BasicTransform ] = None ,
48
49
val_transform : A .Compose | None | list [A .BasicTransform ] = None ,
49
50
test_transform : A .Compose | None | list [A .BasicTransform ] = None ,
50
51
aug : AugmentationSequential = None ,
51
- ** kwargs : Any
52
+ ** kwargs : Any ,
52
53
) -> None :
53
54
54
55
super ().__init__ (MBigEarthNonGeo , batch_size , num_workers , ** kwargs )
55
-
56
+
56
57
bands = kwargs .get ("bands" , MBigEarthNonGeo .all_band_names )
57
58
self .means = torch .tensor ([MEANS [b ] for b in bands ])
58
59
self .stds = torch .tensor ([STDS [b ] for b in bands ])
59
60
self .train_transform = wrap_in_compose_is_list (train_transform )
60
61
self .val_transform = wrap_in_compose_is_list (val_transform )
61
62
self .test_transform = wrap_in_compose_is_list (test_transform )
62
63
self .data_root = data_root
63
- self .aug = AugmentationSequential (K .Normalize (self .means , self .stds ), data_keys = ["image" ]) if aug is None else aug
64
-
64
+ self .aug = (
65
+ AugmentationSequential (K .Normalize (self .means , self .stds ), data_keys = ["image" ]) if aug is None else aug
66
+ )
67
+
65
68
def setup (self , stage : str ) -> None :
66
69
if stage in ["fit" ]:
67
- self .train_dataset = self .dataset_class (
70
+ self .train_dataset = self .dataset_class (
68
71
split = "train" , data_root = self .data_root , transform = self .train_transform , ** self .kwargs
69
72
)
70
73
if stage in ["fit" , "validate" ]:
@@ -73,5 +76,5 @@ def setup(self, stage: str) -> None:
73
76
)
74
77
if stage in ["test" ]:
75
78
self .test_dataset = self .dataset_class (
76
- split = "test" ,data_root = self .data_root , transform = self .test_transform , ** self .kwargs
79
+ split = "test" , data_root = self .data_root , transform = self .test_transform , ** self .kwargs
77
80
)
0 commit comments