Skip to content

Tried to use timm-3d encoders, but #8

@KeesariVigneshwarReddy

Description

@KeesariVigneshwarReddy

Program

# Reproduction for issue #8: timm-3d (tu-*) MaxViT encoders fail inside
# window_partition() when a stage's spatial size is not divisible by the
# attention window size.
import segmentation_models_pytorch_3d as smp
import torch

encoder_name = 'tu-maxvit_base_tf_224.in21k'
model = smp.Unet(
    encoder_name=encoder_name,
    encoder_weights=None,  # random init; no pretrained download needed
    in_channels=3,
    classes=1,
)
# MaxViT downsamples by 32x overall, then window-partitions the deepest
# feature map. With 96^3 input, 96 / 32 = 3 per axis, and 3 is not
# divisible by the window size (2) — hence
# "AssertionError: height (3) must be divisible by window (2)".
# A power-of-two spatial size keeps every stage's extent even
# (128 / 32 = 4), so the partition asserts pass.
# NOTE(review): the exact window/partition sizes come from timm_3d's
# model config — confirm against the installed timm_3d version.
model(torch.rand(4, 3, 128, 128, 128)).shape

Error logs

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/tmp/ipykernel_31/4109723633.py in <cell line: 0>()
----> 1 model(torch.rand(4,3,96,96,96)).shape

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/segmentation_models_pytorch_3d/base/model.py in forward(self, x)
     46         self.check_input_shape(x)
     47 
---> 48         features = self.encoder(x)
     49         decoder_output = self.decoder(*features)
     50 

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/segmentation_models_pytorch_3d/encoders/timm_universal.py in forward(self, x)
     28 
     29     def forward(self, x):
---> 30         features = self.model(x)
     31         features = [
     32             x,

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/timm_3d/models/_features.py in forward(self, x)
    280 
    281     def forward(self, x) -> (List[torch.Tensor]):
--> 282         return list(self._collect(x).values())
    283 
    284 

/usr/local/lib/python3.11/dist-packages/timm_3d/models/_features.py in _collect(self, x)
    234                 x = module(x) if first_or_last_module else checkpoint(module, x)
    235             else:
--> 236                 x = module(x)
    237 
    238             if name in self.return_layers:

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/timm_3d/models/maxxvit.py in forward(self, x)
   1097             x = checkpoint_seq(self.blocks, x)
   1098         else:
-> 1099             x = self.blocks(x)
   1100         return x
   1101 

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/container.py in forward(self, input)
    248     def forward(self, input):
    249         for module in self:
--> 250             input = module(input)
    251         return input
    252 

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/timm_3d/models/maxxvit.py in forward(self, x)
    987             x = x.permute(0, 2, 3, 4, 1)  # to NHWDC (channels-last)
    988         if self.attn_block is not None:
--> 989             x = self.attn_block(x)
    990         x = self.attn_grid(x)
    991         if not self.nchwd_attn:

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/timm_3d/models/maxxvit.py in forward(self, x)
    765         tmp = self.norm1(x)
    766         # print("!K!", tmp.shape)
--> 767         tmp2 = self._partition_attn(tmp)
    768         # print("!L!", tmp2.shape)
    769         x = x + self.drop_path1(self.ls1(tmp2))

/usr/local/lib/python3.11/dist-packages/timm_3d/models/maxxvit.py in _partition_attn(self, x)
    746         if self.partition_block:
    747             # print('W part', img_size)
--> 748             partitioned = window_partition(x, self.partition_size)
    749             # print(partitioned.shape)
    750         else:

/usr/local/lib/python3.11/dist-packages/timm_3d/models/maxxvit.py in window_partition(x, window_size)
    649 def window_partition(x, window_size: List[int]):
    650     B, H, W, D, C = x.shape
--> 651     _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
    652     _assert(W % window_size[1] == 0, f'height ({W}) must be divisible by window ({window_size[1]})')
    653     _assert(D % window_size[2] == 0, f'height ({D}) must be divisible by window ({window_size[2]})')

/usr/local/lib/python3.11/dist-packages/torch/__init__.py in _assert(condition, message)
   2038             _assert, (condition,), condition, message
   2039         )
-> 2040     assert condition, message
   2041 
   2042 

AssertionError: height (3) must be divisible by window (2)

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions