Skip to content

Commit 3aefe1b

Browse files
This function must be shared
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 9f840ea commit 3aefe1b

File tree

2 files changed

+20
-15
lines changed

2 files changed

+20
-15
lines changed

terratorch/datamodules/generic_pixel_wise_data_module.py

+1-15
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,12 @@
1717
from torchgeo.transforms import AugmentationSequential
1818

1919
from terratorch.datasets import GenericNonGeoPixelwiseRegressionDataset, GenericNonGeoSegmentationDataset, HLSBands
20-
20+
from terratorch.io.file import load_from_file_or_attribute
2121

2222
def wrap_in_compose_is_list(transform_list):
2323
# set check shapes to false because of the multitemporal case
2424
return A.Compose(transform_list, is_check_shapes=False) if isinstance(transform_list, Iterable) else transform_list
2525

26-
def load_from_file_or_attribute(value:list[float] | str):
27-
28-
if type(value) == list:
29-
return value
30-
elif type(str): # It can be the path for a file
31-
if os.path.isfile(value):
32-
try:
33-
content = np.genfromtxt(value).tolist()
34-
except:
35-
raise Exception(f"File must be txt, but received {value}")
36-
else:
37-
raise Exception("It seems that {value} does not exist or is not a file.")
38-
39-
return content
4026

4127
# def collate_fn_list_dicts(batch):
4228
# metadata = []

terratorch/io/file.py

+19
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import importlib
33
from torch import nn
4+
import numpy as np
45

56
def open_generic_torch_model(model: type | str = None,
67
model_kwargs: dict = None,
@@ -51,3 +52,21 @@ def load_torch_weights(model:nn.Module=None, save_dir: str = None, name: str = N
5152
)
5253

5354
return model
55+
56+
def load_from_file_or_attribute(value: list[float]|str):
57+
58+
if isinstance(value, list):
59+
return value
60+
elif isinstance(value, str): # It can be the path for a file
61+
if os.path.isfile(value):
62+
try:
63+
print(value)
64+
content = np.genfromtxt(value).tolist()
65+
except:
66+
raise Exception(f"File must be txt, but received {value}")
67+
else:
68+
raise Exception(f"The input {value} does not exist or is not a file.")
69+
70+
return content
71+
72+

0 commit comments

Comments
 (0)