Skip to content

Commit 194a27a

Browse files
Testing FCNDecoder
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 125c148 commit 194a27a

File tree

1 file changed

+29
-1
lines changed

1 file changed

+29
-1
lines changed

tests/test_decoders.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import unittest
2+
import pytest
23

34
import torch
45
from torch import nn
56

67
from terratorch.models.decoders.upernet_decoder import ConvModule
7-
from terratorch.models.decoders import UperNetDecoder
8+
from terratorch.models.decoders import FCNDecoder
89

910
class TestConvModule(unittest.TestCase):
1011
def setUp(self):
@@ -32,4 +33,31 @@ def test_conv_weight_shape(self):
3233
def test_norm_weight_shape(self):
3334
self.assertEqual(self.module.norm.weight.shape, (self.out_channels,))
3435

36+
class TestFCNDecoder(unittest.TestCase):
37+
38+
def test_fcn_decoder(self):
39+
# create inputs
40+
batch_size = 32
41+
height = 32
42+
width = 32
43+
num_channels = 3
44+
embed_dim = [64, 128, 256]
45+
num_convs = 4
46+
in_index = -1
47+
48+
# create model
49+
decoder = FCNDecoder(
50+
embed_dim=embed_dim,
51+
channels=num_channels,
52+
num_convs=num_convs,
53+
in_index=in_index
54+
)
55+
56+
# create input tensor
57+
x = torch.rand((batch_size, embed_dim[in_index], height, width))
58+
# get output shape
59+
out = decoder([None, x])
60+
out_shape = out.shape
3561

62+
# check output shape
63+
self.assertEqual(out_shape, (batch_size, num_channels, (2**num_convs)*height, (2**num_convs)*width))

0 commit comments

Comments
 (0)