Trying to solve issues with UNet and ASPPHeads. #457

Merged · 18 commits · Mar 14, 2025
30 changes: 18 additions & 12 deletions terratorch/models/backbones/unet.py
@@ -14,6 +14,9 @@
from terratorch.models.backbones.utils import UpConvBlock, BasicConvBlock
from terratorch.models.decoders.utils import ConvModule

from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY

@TERRATORCH_BACKBONE_REGISTRY.register
class UNet(nn.Module):
"""UNet backbone.

@@ -22,7 +25,7 @@ class UNet(nn.Module):

Args:
in_channels (int): Number of input image channels. Default: 3.
base_channels (int): Number of base channels of each stage.
out_channels (int): Number of base channels of each stage.
The output channels of the first stage. Default: 64.
num_stages (int): Number of stages in encoder, normally 5. Default: 5.
strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
@@ -74,7 +77,7 @@ class UNet(nn.Module):

def __init__(self,
in_channels=3,
base_channels=64,
out_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
@@ -149,7 +152,7 @@ def __init__(self,
self.strides = strides
self.downsamples = downsamples
self.norm_eval = norm_eval
self.base_channels = base_channels
self.out_channels = [out_channels * 2**i for i in reversed(range(num_stages))]

self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
@@ -163,9 +166,9 @@ def __init__(self,
self.decoder.append(
UpConvBlock(
conv_block=BasicConvBlock,
in_channels=base_channels * 2**i,
skip_channels=base_channels * 2**(i - 1),
out_channels=base_channels * 2**(i - 1),
in_channels=out_channels * 2**i,
skip_channels=out_channels * 2**(i - 1),
out_channels=out_channels * 2**(i - 1),
num_convs=dec_num_convs[i - 1],
stride=1,
dilation=dec_dilations[i - 1],
@@ -180,7 +183,7 @@ def __init__(self,
enc_conv_block.append(
BasicConvBlock(
in_channels=in_channels,
out_channels=base_channels * 2**i,
out_channels=out_channels * 2**i,
num_convs=enc_num_convs[i],
stride=strides[i],
dilation=enc_dilations[i],
@@ -191,11 +194,15 @@ def __init__(self,
dcn=None,
plugins=None))
self.encoder.append((nn.Sequential(*enc_conv_block)))
in_channels = base_channels * 2**i
in_channels = out_channels * 2**i

def forward(self, x):
x = x[0]
self._check_input_divisible(x)

# We can check just the first image, since the batch
# was already approved by the stackability test, which means
# all images have the same dimensions.
self._check_input_divisible(x[0])

enc_outs = []
for enc in self.encoder:
x = enc(x)
@@ -204,8 +211,7 @@ def forward(self, x):
for i in reversed(range(len(self.decoder))):
x = self.decoder[i](enc_outs[i], x)
dec_outs.append(x)

return dec_outs[-1]
return dec_outs

def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer
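For context, a minimal usage sketch of the backbone after these changes. The defaults and shapes below are illustrative assumptions; what the diff itself confirms is the registry registration, the `base_channels` → `out_channels` rename, the derived per-stage `out_channels` list, and that `forward` now returns the full list of outputs rather than only the last one.

```python
import torch

from terratorch.models.backbones.unet import UNet

# Illustrative: with out_channels=64 and num_stages=5, the new
# `out_channels` attribute holds per-stage widths, deepest first:
# [64 * 2**4, ..., 64 * 2**0] == [1024, 512, 256, 128, 64].
backbone = UNet(in_channels=3, out_channels=64, num_stages=5)
print(backbone.out_channels)

# forward() unwraps x[0], so the batch is passed wrapped in a list,
# and the spatial size must pass the divisibility check (224 is
# divisible by 2**4 under the assumed default downsampling). After
# this PR, a list of feature maps is returned, not just the last one.
x = [torch.randn(2, 3, 224, 224)]
dec_outs = backbone(x)
for feat in dec_outs:
    print(feat.shape)
```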
14 changes: 11 additions & 3 deletions terratorch/models/decoders/aspp_head.py
@@ -2,8 +2,10 @@
from torch import nn
import numpy as np

from terratorch.registry import TERRATORCH_DECODER_REGISTRY
from .utils import ConvModule, resize

@TERRATORCH_DECODER_REGISTRY.register
class ASPPModule(nn.Module):
"""Atrous Spatial Pyramid Pooling (ASPP) Module.

@@ -57,6 +59,7 @@ def forward(self, x):

return outs

@TERRATORCH_DECODER_REGISTRY.register
class ASPPHead(nn.Module):
"""Rethinking Atrous Convolution for Semantic Image Segmentation.

@@ -183,13 +186,15 @@ def _forward_feature(self, inputs):
H, W) which is feature map for last layer of decoder head.
"""
inputs = self._transform_inputs(inputs)

aspp_outs = [
resize(
self.image_pool(inputs),
size=inputs.size()[2:],
mode='bilinear',
align_corners=self.align_corners)
]

aspp_outs.extend(self.aspp_modules(inputs))
aspp_outs = torch.cat(aspp_outs, dim=1)
feats = self.bottleneck(aspp_outs)
@@ -202,6 +207,7 @@ def forward(self, inputs):

return output

@TERRATORCH_DECODER_REGISTRY.register
class ASPPSegmentationHead(ASPPHead):
"""Rethinking Atrous Convolution for Semantic Image Segmentation.

@@ -213,7 +219,8 @@ class ASPPSegmentationHead(ASPPHead):
Default: (1, 6, 12, 18).
"""

def __init__(self, dilations:list | tuple =(1, 6, 12, 18),
def __init__(self, channel_list,
dilations:list | tuple =(1, 6, 12, 18),
in_channels:int=None,
channels:int=None,
num_classes:int=2,
@@ -255,6 +262,7 @@ def forward(self, inputs):

return output

@TERRATORCH_DECODER_REGISTRY.register
class ASPPRegressionHead(ASPPHead):
"""Rethinking Atrous Convolution for regression.

@@ -266,7 +274,8 @@ class ASPPRegressionHead(ASPPHead):
Default: (1, 6, 12, 18).
"""

def __init__(self, dilations:list | tuple =(1, 6, 12, 18),
def __init__(self, channel_list,
dilations:list | tuple =(1, 6, 12, 18),
in_channels:int=None,
channels:int=None,
out_channels:int=1,
@@ -293,7 +302,6 @@ def __init__(self, dilations:list | tuple =(1, 6, 12, 18),
def regression_head(self, features):

"""PixelWise regression"""

if self.dropout is not None:
features = self.dropout(features)
output = self.conv_reg(features)
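A minimal construction sketch for the updated decoder signature, mirroring the test change further down: `channel_list` is now the first positional argument. The import path and the output shape are assumptions, not confirmed by this diff.

```python
import numpy as np
import torch

from terratorch.models.decoders.aspp_head import ASPPSegmentationHead

# channel_list ([16, 32, 64, 128] here) now comes first; the remaining
# arguments match tests/test_decoders.py below.
decoder = ASPPSegmentationHead(
    [16, 32, 64, 128],
    dilations=(1, 6, 12, 18),
    in_channels=6,
    channels=10,
    num_classes=2,
)

# Like the backbone, the head consumes a list-wrapped batch.
image = [torch.from_numpy(np.random.rand(2, 6, 224, 224).astype("float32"))]
logits = decoder(image)  # expected shape: (2, num_classes, 224, 224)
```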
1 change: 1 addition & 0 deletions terratorch/models/decoders/utils.py
@@ -59,6 +59,7 @@ def resize(input,
'the output would be more aligned if '
f'input size {(input_h, input_w)} is `x+1` and '
f'out size {(output_h, output_w)} is `nx+1`')

return F.interpolate(input, size, scale_factor, mode, align_corners)


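The warning in `resize` encodes the usual `align_corners` rule of thumb: corner pixels only stay exactly aligned when sizes go from `x+1` to `n*x+1`. A standalone illustration using plain `torch.nn.functional` (not terratorch-specific):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 33, 33)

# 33 -> 65 follows the x+1 -> n*x+1 pattern (x=32, n=2): with
# align_corners=True the corner pixels map exactly onto each other.
aligned = F.interpolate(x, size=(65, 65), mode="bilinear", align_corners=True)

# 32 -> 64 does not fit the pattern; this is the combination the
# resize() helper above warns about when align_corners=True.
y = torch.randn(1, 1, 32, 32)
warned = F.interpolate(y, size=(64, 64), mode="bilinear", align_corners=True)
```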
1 change: 0 additions & 1 deletion terratorch/tasks/classification_tasks.py
@@ -219,7 +219,6 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
y = batch["label"]
other_keys = batch.keys() - {"image", "label", "filename"}
rest = {k: batch[k] for k in other_keys}

model_output: ModelOutput = self(x, **rest)
loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
1 change: 0 additions & 1 deletion terratorch/tasks/regression_tasks.py
@@ -289,7 +289,6 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
y = batch["mask"]
other_keys = batch.keys() - {"image", "mask", "filename"}
rest = {k: batch[k] for k in other_keys}

model_output: ModelOutput = self(x, **rest)
loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
3 changes: 0 additions & 3 deletions terratorch/tasks/segmentation_tasks.py
@@ -242,7 +242,6 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
other_keys = batch.keys() - {"image", "mask", "filename"}

rest = {k: batch[k] for k in other_keys}

model_output: ModelOutput = self(x, **rest)
loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
@@ -264,7 +263,6 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
other_keys = batch.keys() - {"image", "mask", "filename"}

rest = {k: batch[k] for k in other_keys}

model_output: ModelOutput = self(x, **rest)
if dataloader_idx >= len(self.test_loss_handler):
msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
@@ -291,7 +289,6 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
other_keys = batch.keys() - {"image", "mask", "filename"}
rest = {k: batch[k] for k in other_keys}
model_output: ModelOutput = self(x, **rest)

loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
y_hat_hard = to_segmentation_prediction(model_output)
2 changes: 1 addition & 1 deletion tests/test_decoders.py
@@ -17,7 +17,7 @@ def test_aspphead():
dilations = (1, 6, 12, 18)
in_channels = 6
channels = 10
decoder = ASPPSegmentationHead(dilations=dilations, in_channels=in_channels, channels=channels, num_classes=2)
decoder = ASPPSegmentationHead([16, 32, 64, 128], dilations=dilations, in_channels=in_channels, channels=channels, num_classes=2)

image = [torch.from_numpy(np.random.rand(2, 6, 224, 224).astype("float32"))]
