Skip to content

Commit 125c148

Browse files
Testing ConvModule decoder
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 52a3c13 commit 125c148

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tests/test_decoders.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import unittest
2+
3+
import torch
4+
from torch import nn
5+
6+
from terratorch.models.decoders.upernet_decoder import ConvModule
7+
from terratorch.models.decoders import UperNetDecoder
8+
9+
class TestConvModule(unittest.TestCase):
10+
def setUp(self):
11+
self.in_channels = 3
12+
self.out_channels = 64
13+
self.kernel_size = 3
14+
self.padding = 1
15+
self.inplace = True
16+
self.batch_size = 8
17+
self.input_shape = (self.batch_size, self.in_channels, 256, 256)
18+
19+
self.module = ConvModule(
20+
self.in_channels, self.out_channels, self.kernel_size, self.padding, self.inplace
21+
)
22+
23+
self.input = torch.rand(self.input_shape)
24+
25+
def test_forward(self):
26+
output = self.module(self.input)
27+
self.assertEqual(output.shape, self.input_shape[:1] + (self.out_channels,) + output.shape[2:])
28+
29+
def test_conv_weight_shape(self):
30+
self.assertEqual(self.module.conv.weight.shape, (self.out_channels, self.in_channels, self.kernel_size, self.kernel_size))
31+
32+
def test_norm_weight_shape(self):
33+
self.assertEqual(self.module.norm.weight.shape, (self.out_channels,))
34+
35+

0 commit comments

Comments
 (0)