-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Trying to solve issues with UNet and ASPPHeads. #457
base: main
Are you sure you want to change the base?
Changes from 10 commits
0cff8af
83d04e7
c2837df
6c5510c
12bf967
ca2dd74
ed3ebfa
8d58ec3
a0b550e
555d800
b502907
9e8cddf
d0a3e98
6dec32c
9bb5dae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,8 +2,10 @@ | |
from torch import nn | ||
import numpy as np | ||
|
||
from terratorch.registry import TERRATORCH_DECODER_REGISTRY | ||
from .utils import ConvModule, resize | ||
|
||
@TERRATORCH_DECODER_REGISTRY.register | ||
class ASPPModule(nn.Module): | ||
"""Atrous Spatial Pyramid Pooling (ASPP) Module. | ||
|
||
|
@@ -57,6 +59,7 @@ def forward(self, x): | |
|
||
return outs | ||
|
||
@TERRATORCH_DECODER_REGISTRY.register | ||
class ASPPHead(nn.Module): | ||
"""Rethinking Atrous Convolution for Semantic Image Segmentation. | ||
|
||
|
@@ -183,13 +186,15 @@ def _forward_feature(self, inputs): | |
H, W) which is feature map for last layer of decoder head. | ||
""" | ||
inputs = self._transform_inputs(inputs) | ||
|
||
aspp_outs = [ | ||
resize( | ||
self.image_pool(inputs), | ||
size=inputs.size()[2:], | ||
mode='bilinear', | ||
align_corners=self.align_corners) | ||
] | ||
|
||
aspp_outs.extend(self.aspp_modules(inputs)) | ||
aspp_outs = torch.cat(aspp_outs, dim=1) | ||
feats = self.bottleneck(aspp_outs) | ||
|
@@ -202,6 +207,7 @@ def forward(self, inputs): | |
|
||
return output | ||
|
||
@TERRATORCH_DECODER_REGISTRY.register | ||
class ASPPSegmentationHead(ASPPHead): | ||
"""Rethinking Atrous Convolution for Semantic Image Segmentation. | ||
|
||
|
@@ -255,6 +261,7 @@ def forward(self, inputs): | |
|
||
return output | ||
|
||
@TERRATORCH_DECODER_REGISTRY.register | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is actually for line 276 but I can only review changed lines The way that To resolve this I'd suggest
Note this only works when input_transform is None. I currently only care about this use case wherein we select one item from the channel list and this corresponds to the item we pick from the dec_outs. Though if you want to make it more generic mmseg has a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks. I'll try to adapt and include it in the ASPP module. |
||
class ASPPRegressionHead(ASPPHead): | ||
"""Rethinking Atrous Convolution for regression. | ||
|
||
|
@@ -293,7 +300,6 @@ def __init__(self, dilations:list | tuple =(1, 6, 12, 18), | |
def regression_head(self, features): | ||
|
||
"""PixelWise regression""" | ||
|
||
if self.dropout is not None: | ||
features = self.dropout(features) | ||
output = self.conv_reg(features) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is supposed to represent the out_channels per dec_outs item then it's not accurate, it should be:
For out_channels = 32, num_stages = 5, self.out_channels = [512, 256, 128, 64, 32] which corresponds to the channel of items in dec_outs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. Just filling space to see things running.