Skip to content

Commit 174f2f1

Browse files
Testing to use strings to define a model
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 5dba481 commit 174f2f1

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

terratorch/datasets/generic_pixel_wise_dataset.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def __init__(
135135
bands_type = self._bands_as_int_or_str(dataset_bands, output_bands)
136136

137137
if bands_type == str:
138-
raise UserWarning("When the bands are defined as str, guarantee your input files"+
138+
UserWarning("When the bands are defined as str, guarantee your input files"+
139139
"are organized by band and all have its specific name.")
140140

141141
if self.output_bands and not self.dataset_bands:
@@ -203,7 +203,7 @@ def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type:
203203
band_type[b] = str
204204
else:
205205
pass
206-
if band_type.cound(band_type[0]) == len(band_type):
206+
if band_type.count(band_type[0]) == len(band_type):
207207
return band_type[0]
208208
else:
209209
raise Exception("The bands must be or all str or all int.")

tests/test_finetune.py

+14
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,20 @@ def test_finetune_bands_intervals(model_name):
3636
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_band_interval.yaml"]
3737
_ = build_lightning_cli(command_list)
3838

39+
@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
40+
def test_finetune_bands_intervals(model_name):
41+
42+
model_instance = timm.create_model(model_name)
43+
44+
state_dict = model_instance.state_dict()
45+
46+
torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))
47+
48+
# Running the terratorch CLI
49+
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_string.yaml"]
50+
_ = build_lightning_cli(command_list)
51+
52+
3953
"""
4054
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
4155
def test_finetune_multiple_backbones(model_name):

0 commit comments

Comments
 (0)