Skip to content

Commit 1e00471

Browse files
Reformatting the source code using black
Signed-off-by: Joao Lucas de Sousa Almeida <[email protected]>
1 parent 308d540 commit 1e00471

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+738
-703
lines changed

examples/scripts/convert_sen1floods11_splits.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
data = np.genfromtxt(input_file, delimiter=',', dtype=str)
1414

15-
col1 = data[:,0].tolist()
15+
col1 = data[:, 0].tolist()
1616

1717
col1_ = ["_".join(i.split("_")[:2]) for i in col1]
1818

examples/scripts/instantiate_satmae_backbone.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
1-
import torch
1+
import torch
22
import numpy as np
33

44
from models_mae import MaskedAutoencoderViT
55

6-
kwargs = {"img_size":224,
7-
"patch_size":16,
8-
"in_chans":3,
9-
"embed_dim":1024,
10-
"depth":24,
11-
"num_heads":16,
12-
"decoder_embed_dim":512,
13-
"decoder_depth":8,
14-
"decoder_num_heads":16,
15-
"mlp_ratio":4.}
6+
kwargs = {
7+
"img_size": 224,
8+
"patch_size": 16,
9+
"in_chans": 3,
10+
"embed_dim": 1024,
11+
"depth": 24,
12+
"num_heads": 16,
13+
"decoder_embed_dim": 512,
14+
"decoder_depth": 8,
15+
"decoder_num_heads": 16,
16+
"mlp_ratio": 4.0,
17+
}
1618

1719
vit_mae = MaskedAutoencoderViT(**kwargs)
1820

@@ -29,4 +31,3 @@
2931

3032
print(f"Output shape: {reconstructed.shape}")
3133
print("Done.")
32-

terratorch/cli_tools.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def write_tiff(img_wrt, filename, metadata):
8080
return filename
8181

8282

83-
def save_prediction(prediction, input_file_name, out_dir, dtype:str="int16"):
83+
def save_prediction(prediction, input_file_name, out_dir, dtype: str = "int16"):
8484
mask, metadata = open_tiff(input_file_name)
8585
mask = np.where(mask == metadata["nodata"], 1, 0)
8686
mask = np.max(mask, axis=0)
@@ -310,10 +310,11 @@ def instantiate_classes(self) -> None:
310310
config = self.config
311311
if hasattr(config, "predict_output_dir"):
312312
self.trainer.predict_output_dir = config.predict_output_dir
313-
313+
314314
if hasattr(config, "out_dtype"):
315315
self.trainer.out_dtype = config.out_dtype
316316

317+
317318
def build_lightning_cli(
318319
args: ArgsType = None,
319320
run=True, # noqa: FBT002
@@ -413,8 +414,12 @@ def from_config(
413414
]
414415

415416
if predict_dataset_bands is not None:
416-
arguments.extend([ "--data.init_args.predict_dataset_bands",
417-
"[" + ",".join(predict_dataset_bands) + "]",])
417+
arguments.extend(
418+
[
419+
"--data.init_args.predict_dataset_bands",
420+
"[" + ",".join(predict_dataset_bands) + "]",
421+
]
422+
)
418423

419424
cli = build_lightning_cli(arguments, run=False)
420425
trainer = cli.trainer

terratorch/datamodules/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,5 @@
5050
"MChesapeakeLandcoverNonGeoDataModule",
5151
"MPv4gerSegNonGeoDataModule",
5252
"MSACropTypeNonGeoDataModule",
53-
"MNeonTreeNonGeoDataModule"
53+
"MNeonTreeNonGeoDataModule",
5454
)

terratorch/datamodules/m_SA_crop_type.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"WATER_VAPOR": 69.904566,
2323
"SWIR_1": 83.626811,
2424
"SWIR_2": 65.767679,
25-
"CLOUD_PROBABILITY": 0.0
25+
"CLOUD_PROBABILITY": 0.0,
2626
}
2727

2828
STDS = {
@@ -38,36 +38,41 @@
3838
"WATER_VAPOR": 21.877766438821954,
3939
"SWIR_1": 28.14418826277069,
4040
"SWIR_2": 27.2346215312965,
41-
"CLOUD_PROBABILITY": 0.0
41+
"CLOUD_PROBABILITY": 0.0,
4242
}
4343

44+
4445
class MSACropTypeNonGeoDataModule(NonGeoDataModule):
4546
def __init__(
46-
self,
47-
batch_size: int = 8,
48-
num_workers: int = 0,
47+
self,
48+
batch_size: int = 8,
49+
num_workers: int = 0,
4950
data_root: str = "./",
5051
train_transform: A.Compose | None | list[A.BasicTransform] = None,
5152
val_transform: A.Compose | None | list[A.BasicTransform] = None,
5253
test_transform: A.Compose | None | list[A.BasicTransform] = None,
5354
aug: AugmentationSequential = None,
54-
**kwargs: Any
55+
**kwargs: Any,
5556
) -> None:
5657

5758
super().__init__(MSACropTypeNonGeo, batch_size, num_workers, **kwargs)
58-
59+
5960
bands = kwargs.get("bands", MSACropTypeNonGeo.all_band_names)
6061
self.means = torch.tensor([MEANS[b] for b in bands])
6162
self.stds = torch.tensor([STDS[b] for b in bands])
6263
self.train_transform = wrap_in_compose_is_list(train_transform)
6364
self.val_transform = wrap_in_compose_is_list(val_transform)
6465
self.test_transform = wrap_in_compose_is_list(test_transform)
6566
self.data_root = data_root
66-
self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image", "mask"]) if aug is None else aug
67-
67+
self.aug = (
68+
AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image", "mask"])
69+
if aug is None
70+
else aug
71+
)
72+
6873
def setup(self, stage: str) -> None:
6974
if stage in ["fit"]:
70-
self.train_dataset = self.dataset_class(
75+
self.train_dataset = self.dataset_class(
7176
split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
7277
)
7378
if stage in ["fit", "validate"]:
@@ -76,5 +81,5 @@ def setup(self, stage: str) -> None:
7681
)
7782
if stage in ["test"]:
7883
self.test_dataset = self.dataset_class(
79-
split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs
84+
split="test", data_root=self.data_root, transform=self.test_transform, **self.kwargs
8085
)

terratorch/datamodules/m_bigearthnet.py

+34-31
Original file line numberDiff line numberDiff line change
@@ -9,62 +9,65 @@
99
from terratorch.datamodules.utils import wrap_in_compose_is_list
1010

1111
MEANS = {
12-
"COASTAL_AEROSOL": 378.4027,
13-
"BLUE": 482.2730,
14-
"GREEN": 706.5345,
15-
"RED": 720.9285,
16-
"RED_EDGE_1": 1100.6688,
12+
"COASTAL_AEROSOL": 378.4027,
13+
"BLUE": 482.2730,
14+
"GREEN": 706.5345,
15+
"RED": 720.9285,
16+
"RED_EDGE_1": 1100.6688,
1717
"RED_EDGE_2": 1909.2914,
18-
"RED_EDGE_3": 2191.6985,
19-
"NIR_BROAD": 2336.8706,
20-
"NIR_NARROW": 2394.7449,
21-
"WATER_VAPOR": 2368.3127,
22-
"SWIR_1": 1875.2487,
23-
"SWIR_2": 1229.3818
18+
"RED_EDGE_3": 2191.6985,
19+
"NIR_BROAD": 2336.8706,
20+
"NIR_NARROW": 2394.7449,
21+
"WATER_VAPOR": 2368.3127,
22+
"SWIR_1": 1875.2487,
23+
"SWIR_2": 1229.3818,
2424
}
2525

26-
STDS = {
27-
"COASTAL_AEROSOL": 157.5666,
28-
"BLUE": 255.0429,
29-
"GREEN": 303.1750,
30-
"RED": 391.2943,
31-
"RED_EDGE_1": 380.7916,
32-
"RED_EDGE_2": 551.6558,
26+
STDS = {
27+
"COASTAL_AEROSOL": 157.5666,
28+
"BLUE": 255.0429,
29+
"GREEN": 303.1750,
30+
"RED": 391.2943,
31+
"RED_EDGE_1": 380.7916,
32+
"RED_EDGE_2": 551.6558,
3333
"RED_EDGE_3": 638.8196,
3434
"NIR_BROAD": 744.2009,
35-
"NIR_NARROW": 675.4041,
36-
"WATER_VAPOR": 561.0154,
37-
"SWIR_1": 563.4095,
38-
"SWIR_2": 479.1786
35+
"NIR_NARROW": 675.4041,
36+
"WATER_VAPOR": 561.0154,
37+
"SWIR_1": 563.4095,
38+
"SWIR_2": 479.1786,
3939
}
4040

41+
4142
class MBigEarthNonGeoDataModule(NonGeoDataModule):
4243
def __init__(
43-
self,
44-
batch_size: int = 8,
45-
num_workers: int = 0,
44+
self,
45+
batch_size: int = 8,
46+
num_workers: int = 0,
4647
data_root: str = "./",
4748
train_transform: A.Compose | None | list[A.BasicTransform] = None,
4849
val_transform: A.Compose | None | list[A.BasicTransform] = None,
4950
test_transform: A.Compose | None | list[A.BasicTransform] = None,
5051
aug: AugmentationSequential = None,
51-
**kwargs: Any
52+
**kwargs: Any,
5253
) -> None:
5354

5455
super().__init__(MBigEarthNonGeo, batch_size, num_workers, **kwargs)
55-
56+
5657
bands = kwargs.get("bands", MBigEarthNonGeo.all_band_names)
5758
self.means = torch.tensor([MEANS[b] for b in bands])
5859
self.stds = torch.tensor([STDS[b] for b in bands])
5960
self.train_transform = wrap_in_compose_is_list(train_transform)
6061
self.val_transform = wrap_in_compose_is_list(val_transform)
6162
self.test_transform = wrap_in_compose_is_list(test_transform)
6263
self.data_root = data_root
63-
self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug
64-
64+
self.aug = (
65+
AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug
66+
)
67+
6568
def setup(self, stage: str) -> None:
6669
if stage in ["fit"]:
67-
self.train_dataset = self.dataset_class(
70+
self.train_dataset = self.dataset_class(
6871
split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
6972
)
7073
if stage in ["fit", "validate"]:
@@ -73,5 +76,5 @@ def setup(self, stage: str) -> None:
7376
)
7477
if stage in ["test"]:
7578
self.test_dataset = self.dataset_class(
76-
split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs
79+
split="test", data_root=self.data_root, transform=self.test_transform, **self.kwargs
7780
)

terratorch/datamodules/m_brick_kiln.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"WATER_VAPOR": 1129.8171906000355,
2323
"CIRRUS": 83.27188605598549,
2424
"SWIR_1": 90.54924599052214,
25-
"SWIR_2": 68.98768652434848
25+
"SWIR_2": 68.98768652434848,
2626
}
2727

2828
STDS = {
@@ -38,37 +38,39 @@
3838
"WATER_VAPOR": 704.0219637458916,
3939
"CIRRUS": 36.355745901131705,
4040
"SWIR_1": 28.004671947623894,
41-
"SWIR_2": 24.268892726362033
41+
"SWIR_2": 24.268892726362033,
4242
}
4343

44-
class MBrickKilnNonGeoDataModule(NonGeoDataModule):
4544

45+
class MBrickKilnNonGeoDataModule(NonGeoDataModule):
4646
def __init__(
47-
self,
48-
batch_size: int = 8,
49-
num_workers: int = 0,
47+
self,
48+
batch_size: int = 8,
49+
num_workers: int = 0,
5050
data_root: str = "./",
5151
train_transform: A.Compose | None | list[A.BasicTransform] = None,
5252
val_transform: A.Compose | None | list[A.BasicTransform] = None,
5353
test_transform: A.Compose | None | list[A.BasicTransform] = None,
5454
aug: AugmentationSequential = None,
55-
**kwargs: Any
55+
**kwargs: Any,
5656
) -> None:
5757

5858
super().__init__(MBrickKilnNonGeo, batch_size, num_workers, **kwargs)
59-
59+
6060
bands = kwargs.get("bands", MBrickKilnNonGeo.all_band_names)
6161
self.means = torch.tensor([MEANS[b] for b in bands])
6262
self.stds = torch.tensor([STDS[b] for b in bands])
6363
self.train_transform = wrap_in_compose_is_list(train_transform)
6464
self.val_transform = wrap_in_compose_is_list(val_transform)
6565
self.test_transform = wrap_in_compose_is_list(test_transform)
6666
self.data_root = data_root
67-
self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug
68-
67+
self.aug = (
68+
AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug
69+
)
70+
6971
def setup(self, stage: str) -> None:
7072
if stage in ["fit"]:
71-
self.train_dataset = self.dataset_class(
73+
self.train_dataset = self.dataset_class(
7274
split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
7375
)
7476
if stage in ["fit", "validate"]:
@@ -77,5 +79,5 @@ def setup(self, stage: str) -> None:
7779
)
7880
if stage in ["test"]:
7981
self.test_dataset = self.dataset_class(
80-
split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs
82+
split="test", data_root=self.data_root, transform=self.test_transform, **self.kwargs
8183
)

terratorch/datamodules/m_cashew_plantation.py

+17-13
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"WATER_VAPOR": 2852.87451171875,
2323
"SWIR_1": 2463.933349609375,
2424
"SWIR_2": 1600.9207763671875,
25-
"CLOUD_PROBABILITY": 0.010281000286340714
25+
"CLOUD_PROBABILITY": 0.010281000286340714,
2626
}
2727

2828
STDS = {
@@ -38,37 +38,41 @@
3838
"WATER_VAPOR": 413.8980407714844,
3939
"SWIR_1": 494.97430419921875,
4040
"SWIR_2": 514.4229736328125,
41-
"CLOUD_PROBABILITY": 0.3447800576686859
41+
"CLOUD_PROBABILITY": 0.3447800576686859,
4242
}
4343

44-
class MBeninSmallHolderCashewsNonGeoDataModule(NonGeoDataModule):
4544

45+
class MBeninSmallHolderCashewsNonGeoDataModule(NonGeoDataModule):
4646
def __init__(
47-
self,
48-
batch_size: int = 8,
49-
num_workers: int = 0,
47+
self,
48+
batch_size: int = 8,
49+
num_workers: int = 0,
5050
data_root: str = "./",
5151
train_transform: A.Compose | None | list[A.BasicTransform] = None,
5252
val_transform: A.Compose | None | list[A.BasicTransform] = None,
5353
test_transform: A.Compose | None | list[A.BasicTransform] = None,
5454
aug: AugmentationSequential = None,
55-
**kwargs: Any
55+
**kwargs: Any,
5656
) -> None:
5757

5858
super().__init__(MBeninSmallHolderCashewsNonGeo, batch_size, num_workers, **kwargs)
59-
59+
6060
bands = kwargs.get("bands", MBeninSmallHolderCashewsNonGeo.all_band_names)
6161
self.means = torch.tensor([MEANS[b] for b in bands])
6262
self.stds = torch.tensor([STDS[b] for b in bands])
6363
self.train_transform = wrap_in_compose_is_list(train_transform)
6464
self.val_transform = wrap_in_compose_is_list(val_transform)
6565
self.test_transform = wrap_in_compose_is_list(test_transform)
6666
self.data_root = data_root
67-
self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image", "mask"]) if aug is None else aug
68-
67+
self.aug = (
68+
AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image", "mask"])
69+
if aug is None
70+
else aug
71+
)
72+
6973
def setup(self, stage: str) -> None:
7074
if stage in ["fit"]:
71-
self.train_dataset = self.dataset_class(
75+
self.train_dataset = self.dataset_class(
7276
split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
7377
)
7478
if stage in ["fit", "validate"]:
@@ -77,5 +81,5 @@ def setup(self, stage: str) -> None:
7781
)
7882
if stage in ["test"]:
7983
self.test_dataset = self.dataset_class(
80-
split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs
81-
)
84+
split="test", data_root=self.data_root, transform=self.test_transform, **self.kwargs
85+
)

0 commit comments

Comments
 (0)