Trying to solve issues with UNet and ASPPHeads. #457

Open
wants to merge 15 commits into main
2 changes: 1 addition & 1 deletion terratorch/cli_tools.py
@@ -102,7 +102,7 @@ def save_prediction(prediction, input_file_name, out_dir, dtype:str="int16"):
mask = np.where(mask == metadata["nodata"], 1, 0)
mask = np.max(mask, axis=0)
result = np.where(mask == 1, -1, prediction.detach().cpu())

print(result.shape)
##### Save file to disk
metadata["count"] = 1
metadata["dtype"] = dtype
30 changes: 18 additions & 12 deletions terratorch/models/backbones/unet.py
@@ -14,6 +14,9 @@
from terratorch.models.backbones.utils import UpConvBlock, BasicConvBlock
from terratorch.models.decoders.utils import ConvModule

from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY

@TERRATORCH_BACKBONE_REGISTRY.register
class UNet(nn.Module):
"""UNet backbone.

@@ -22,7 +25,7 @@ class UNet(nn.Module):

Args:
in_channels (int): Number of input image channels. Default: 3.
base_channels (int): Number of base channels of each stage.
out_channels (int): Number of base channels of each stage.
The output channels of the first stage. Default: 64.
num_stages (int): Number of stages in encoder, normally 5. Default: 5.
strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
@@ -74,7 +77,7 @@ class UNet(nn.Module):

def __init__(self,
in_channels=3,
base_channels=64,
out_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
@@ -149,7 +152,7 @@ def __init__(self,
self.strides = strides
self.downsamples = downsamples
self.norm_eval = norm_eval
self.base_channels = base_channels
self.out_channels = num_stages*[out_channels]
Collaborator:
If this is supposed to represent the out_channels of each dec_outs item, then it's not accurate; it should be:

self.out_channels = [out_channels * 2**i for i in reversed(range(num_stages))]

For out_channels = 32 and num_stages = 5, this gives self.out_channels = [512, 256, 128, 64, 32], which matches the channels of the items in dec_outs.
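
A quick sanity check of that formula (a minimal sketch; the values below are just the example above, nothing taken from the PR itself):

# Sketch: confirm the proposed per-stage list matches the channels of the
# decoder outputs (dec_outs) for the example values above.
out_channels = 32
num_stages = 5

stage_channels = [out_channels * 2**i for i in reversed(range(num_stages))]
assert stage_channels == [512, 256, 128, 64, 32]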

Member Author:
Thanks. That was just a placeholder to see things running.


self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
@@ -163,9 +166,9 @@ def __init__(self,
self.decoder.append(
UpConvBlock(
conv_block=BasicConvBlock,
in_channels=base_channels * 2**i,
skip_channels=base_channels * 2**(i - 1),
out_channels=base_channels * 2**(i - 1),
in_channels=out_channels * 2**i,
skip_channels=out_channels * 2**(i - 1),
out_channels=out_channels * 2**(i - 1),
num_convs=dec_num_convs[i - 1],
stride=1,
dilation=dec_dilations[i - 1],
@@ -180,7 +183,7 @@
enc_conv_block.append(
BasicConvBlock(
in_channels=in_channels,
out_channels=base_channels * 2**i,
out_channels=out_channels * 2**i,
num_convs=enc_num_convs[i],
stride=strides[i],
dilation=enc_dilations[i],
Expand All @@ -191,11 +194,15 @@ def __init__(self,
dcn=None,
plugins=None))
self.encoder.append((nn.Sequential(*enc_conv_block)))
in_channels = base_channels * 2**i
in_channels = out_channels * 2**i

def forward(self, x):
x = x[0]
self._check_input_divisible(x)

# We can check just the first image, since the batch has already
# passed the stackability test, which means all images have the
# same dimensions.
self._check_input_divisible(x[0])

enc_outs = []
for enc in self.encoder:
x = enc(x)
@@ -204,8 +211,7 @@ def forward(self, x):
for i in reversed(range(len(self.decoder))):
x = self.decoder[i](enc_outs[i], x)
dec_outs.append(x)

return dec_outs[-1]
return dec_outs

def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer
8 changes: 7 additions & 1 deletion terratorch/models/decoders/aspp_head.py
@@ -2,8 +2,10 @@
from torch import nn
import numpy as np

from terratorch.registry import TERRATORCH_DECODER_REGISTRY
from .utils import ConvModule, resize

@TERRATORCH_DECODER_REGISTRY.register
class ASPPModule(nn.Module):
"""Atrous Spatial Pyramid Pooling (ASPP) Module.

@@ -57,6 +59,7 @@ def forward(self, x):

return outs

@TERRATORCH_DECODER_REGISTRY.register
class ASPPHead(nn.Module):
"""Rethinking Atrous Convolution for Semantic Image Segmentation.

@@ -183,13 +186,15 @@ def _forward_feature(self, inputs):
H, W) which is feature map for last layer of decoder head.
"""
inputs = self._transform_inputs(inputs)

aspp_outs = [
resize(
self.image_pool(inputs),
size=inputs.size()[2:],
mode='bilinear',
align_corners=self.align_corners)
]

aspp_outs.extend(self.aspp_modules(inputs))
aspp_outs = torch.cat(aspp_outs, dim=1)
feats = self.bottleneck(aspp_outs)
@@ -202,6 +207,7 @@ def forward(self, inputs):

return output

@TERRATORCH_DECODER_REGISTRY.register
class ASPPSegmentationHead(ASPPHead):
"""Rethinking Atrous Convolution for Semantic Image Segmentation.

@@ -255,6 +261,7 @@ def forward(self, inputs):

return output

@TERRATORCH_DECODER_REGISTRY.register
singhshraddha (Collaborator), Feb 25, 2025:
This is actually about line 276, but I can only comment on changed lines.

The way ASPPRegressionHead is set up doesn't comply with how encoder_decoder_factory.py builds decoders: the build expects the first argument to be channel_list, but here it is dilations instead. As a result, the dilations are set incorrectly and I cannot set them from a config.

To resolve this I'd suggest:

  • not accepting in_channels as an input parameter
  • making channel_list the first input parameter
  • using in_index together with channel_list to set in_channels
class ASPPRegressionHead(ASPPHead):

    def __init__(
        self,
        channel_list,
        dilations: list | tuple = (1, 6, 12, 18),
        channels: int = None,
        out_channels: int = 1,
        align_corners=False,
        head_dropout_ratio: float = 0.3,
        input_transform: str = None,
        in_index: int = -1,
        **kwargs,
    ):

        self.in_channels = channel_list[in_index]

Note that this only works when input_transform is None. I currently only care about that use case, where we select one item from the channel list, corresponding to the item we pick from dec_outs. If you want to make it more generic, mmseg has a BaseDecodeHead._init_inputs function that deals with the different kinds of channel_list, input_transform, etc.
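
As a concrete illustration of the channel_list / in_index suggestion (a hedged sketch with hypothetical values, mirroring the UNet example above with out_channels = 32 and num_stages = 5, and assuming input_transform is None):

# Hypothetical example: channel_list mirrors the channels of the UNet
# dec_outs items for out_channels = 32, num_stages = 5.
channel_list = [512, 256, 128, 64, 32]
in_index = -1  # pick the last decoder output (highest resolution)

in_channels = channel_list[in_index]  # 32 -> becomes self.in_channels of the head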

Member Author:
Thanks. I'll try to adapt and include it in the ASPP module.

class ASPPRegressionHead(ASPPHead):
"""Rethinking Atrous Convolution for regression.

@@ -293,7 +300,6 @@ def __init__(self, dilations:list | tuple =(1, 6, 12, 18),
def regression_head(self, features):

"""PixelWise regression"""

if self.dropout is not None:
features = self.dropout(features)
output = self.conv_reg(features)
1 change: 1 addition & 0 deletions terratorch/models/decoders/utils.py
@@ -59,6 +59,7 @@ def resize(input,
'the output would more aligned if '
f'input size {(input_h, input_w)} is `x+1` and '
f'out size {(output_h, output_w)} is `nx+1`')

return F.interpolate(input, size, scale_factor, mode, align_corners)

