Skip to content

Commit 46e7b9a

Browse files
committed
edit test to use less resources for github actions
Signed-off-by: Carlos Gomes <[email protected]>
1 parent 6430c3c commit 46e7b9a

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

tests/test_prithvi_model_factory.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def test_create_classification_model_no_in_channels(backbone, model_factory: Pri
5656
with torch.no_grad():
5757
assert model(model_input).output.shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE
5858

59-
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
59+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
6060
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
6161
def test_create_segmentation_model(backbone, decoder, model_factory: PrithviModelFactory, model_input):
6262
model = model_factory.build_model(
@@ -73,7 +73,7 @@ def test_create_segmentation_model(backbone, decoder, model_factory: PrithviMode
7373
with torch.no_grad():
7474
assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
7575

76-
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
76+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
7777
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
7878
def test_create_segmentation_model_no_in_channels(backbone, decoder, model_factory: PrithviModelFactory, model_input):
7979
model = model_factory.build_model(
@@ -90,7 +90,7 @@ def test_create_segmentation_model_no_in_channels(backbone, decoder, model_facto
9090
assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
9191

9292

93-
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
93+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
9494
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
9595
def test_create_segmentation_model_with_aux_heads(backbone, decoder, model_factory: PrithviModelFactory, model_input):
9696
aux_heads_name = ["first_aux", "second_aux"]
@@ -115,7 +115,7 @@ def test_create_segmentation_model_with_aux_heads(backbone, decoder, model_facto
115115
assert output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
116116

117117

118-
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
118+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
119119
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
120120
def test_create_regression_model(backbone, decoder, model_factory: PrithviModelFactory, model_input):
121121
model = model_factory.build_model(
@@ -131,7 +131,7 @@ def test_create_regression_model(backbone, decoder, model_factory: PrithviModelF
131131
with torch.no_grad():
132132
assert model(model_input).output.shape == EXPECTED_REGRESSION_OUTPUT_SHAPE
133133

134-
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
134+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
135135
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
136136
def test_create_regression_model_no_in_channels(backbone, decoder, model_factory: PrithviModelFactory, model_input):
137137
model = model_factory.build_model(
@@ -146,7 +146,7 @@ def test_create_regression_model_no_in_channels(backbone, decoder, model_factory
146146
with torch.no_grad():
147147
assert model(model_input).output.shape == EXPECTED_REGRESSION_OUTPUT_SHAPE
148148

149-
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
149+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
150150
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
151151
def test_create_regression_model_with_aux_heads(backbone, decoder, model_factory: PrithviModelFactory, model_input):
152152
aux_heads_name = ["first_aux", "second_aux"]
@@ -170,7 +170,7 @@ def test_create_regression_model_with_aux_heads(backbone, decoder, model_factory
170170
assert output.shape == EXPECTED_REGRESSION_OUTPUT_SHAPE
171171

172172

173-
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
173+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
174174
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
175175
def test_create_model_with_extra_bands(backbone, decoder, model_factory: PrithviModelFactory):
176176
model = model_factory.build_model(

0 commit comments

Comments
 (0)