
Trying to solve issues with UNet and ASPPHeads. #457

Open
wants to merge 15 commits into main
Conversation

Joao-L-S-Almeida
Member

Alternative to #456.

@Joao-L-S-Almeida self-assigned this on Feb 25, 2025
@Joao-L-S-Almeida marked this pull request as ready for review on February 25, 2025 20:06
@Joao-L-S-Almeida changed the title from [WiP] New/fix/unet to "Trying to solve issues with UNet and ASPPHeads." on Feb 25, 2025
@@ -149,7 +152,7 @@ def __init__(self,
self.strides = strides
self.downsamples = downsamples
self.norm_eval = norm_eval
self.base_channels = base_channels
self.out_channels = num_stages*[out_channels]
Collaborator

If this is supposed to represent the out_channels per dec_outs item, then it's not accurate; it should be:

self.out_channels = [out_channels * 2**i for i in reversed(range(num_stages))]

For out_channels = 32 and num_stages = 5, this gives self.out_channels = [512, 256, 128, 64, 32], which corresponds to the channel counts of the items in dec_outs.
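A minimal standalone check (plain Python, not terratorch code) of the suggested expression, assuming out_channels = 32 and num_stages = 5 as above:

out_channels = 32
num_stages = 5

# Suggested expression: the deepest decoder stage gets the most channels.
decoder_channels = [out_channels * 2**i for i in reversed(range(num_stages))]
print(decoder_channels)  # [512, 256, 128, 64, 32]

# The current `num_stages * [out_channels]` would give [32, 32, 32, 32, 32] instead.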

Member Author

Thanks. That was just a placeholder to get things running.

@@ -292,7 +292,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
y = batch["mask"]
other_keys = batch.keys() - {"image", "mask", "filename"}
rest = {k: batch[k] for k in other_keys}

self.eval()
Collaborator

I looked into this a bit, and setting the whole model to eval also affects the dropout layers, so I don't think we can do this.

A better way to set just the batch norm layers to eval is mentioned in this post. But, more importantly, I don't think we should only print a warning about using BatchNorm when the batch size is 1, because doing that is just a bad idea anyway.
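For reference, a minimal PyTorch sketch of the pattern usually meant by putting only the batch norm layers in eval mode (the helper name is hypothetical, and this is not necessarily the snippet from the linked post):

import torch.nn as nn

def set_batchnorm_eval(model: nn.Module) -> None:
    # Switch only the BatchNorm layers to eval mode so their running
    # statistics are used; dropout and all other layers keep training behaviour.
    for module in model.modules():
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            module.eval()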

Member Author

I believe it doesn't make much sense to use batch_size=1 for models with batch normalization layers, even if we add some kind of workaround to make it run.
Maybe the best solution is to include a global check, for any model in terratorch, which raises an Exception in these cases and tells the user to adopt batch_size > 1 when BatchNorm is used.
What do you think, @romeokienzler ?
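A rough sketch of the kind of global check being proposed (hypothetical helper, not existing terratorch code):

import torch.nn as nn

def check_batch_size_for_batchnorm(model: nn.Module, batch_size: int) -> None:
    # Raise early if the model relies on batch statistics but the batch size
    # is too small to compute them meaningfully.
    has_batchnorm = any(
        isinstance(m, nn.modules.batchnorm._BatchNorm) for m in model.modules()
    )
    if has_batchnorm and batch_size < 2:
        raise ValueError(
            "batch_size must be > 1 when the model contains BatchNorm layers."
        )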

@@ -255,6 +261,7 @@ def forward(self, inputs):

return output

@TERRATORCH_DECODER_REGISTRY.register
Collaborator
@singhshraddha Feb 25, 2025

This is actually for line 276, but I can only review changed lines.

The way ASPPRegressionHead is set up doesn't comply with how encoder_decoder_factory.py calls the decoder build: the build expects the first argument to be channel_list, but here it is dilations instead. As a result the dilations are set incorrectly and I cannot set them from a config.

To resolve this I'd suggest:

  • not accepting in_channels as an input parameter
  • setting channel_list as the first input parameter
  • using in_index with channel_list to set in_channels
class ASPPRegressionHead(ASPPHead):

    def __init__(
        self,
        channel_list,
        dilations: list | tuple = (1, 6, 12, 18),
        channels: int = None,
        out_channels: int = 1,
        align_corners=False,
        head_dropout_ratio: float = 0.3,
        input_transform: str = None,
        in_index: int = -1,
        **kwargs,
    ):
        # Derive in_channels from the channel list instead of accepting it as an argument.
        self.in_channels = channel_list[in_index]

Note this only works when input_transform is None. I currently only care about this use case, where we select one item from the channel list and it corresponds to the item we pick from dec_outs. If you want to make it more generic, mmseg has a BaseDecodeHead._init_inputs function that deals with the different types of channel_list, input_transform, etc.
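As an illustration, a hypothetical call under the suggested signature (the channel list values are only examples; the factory would pass channel_list as the first positional argument):

head = ASPPRegressionHead(
    [512, 256, 128, 64, 32],   # channel_list coming from the decoder
    dilations=(1, 6, 12, 18),
    channels=256,
    out_channels=1,
    in_index=-1,               # selects the last dec_outs item -> in_channels = 32
)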

Member Author

Thanks. I'll try to adapt and include it in the ASPP module.

@@ -243,7 +243,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
other_keys = batch.keys() - {"image", "mask", "filename"}

rest = {k: batch[k] for k in other_keys}

self.eval()
Collaborator

Same as the review comment in regression_tasks.py about setting model.eval().

Member Author
@Joao-L-S-Almeida Feb 26, 2025

I have removed all these calls. I believe the best approach is to block the usage of batch_size < 2 when BatchNorm is involved. I'll do it in another PR.

Collaborator
@singhshraddha left a comment

Looks fine for my use case. I'll run a model next week using this branch, thanks.
