-
Notifications
You must be signed in to change notification settings - Fork 9
Open
Description
Program
import segmentation_models_pytorch_3d as smp
import torch
encoder_name = 'tu-maxvit_base_tf_224.in21k'
model = smp.Unet(
encoder_name=encoder_name,
encoder_weights=None,
in_channels=3,
classes=1,
)
model(torch.rand(4,3,96,96,96)).shape
Error logs
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
/tmp/ipykernel_31/4109723633.py in <cell line: 0>()
----> 1 model(torch.rand(4,3,96,96,96)).shape
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
1737
1738 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1748
1749 result = None
/usr/local/lib/python3.11/dist-packages/segmentation_models_pytorch_3d/base/model.py in forward(self, x)
46 self.check_input_shape(x)
47
---> 48 features = self.encoder(x)
49 decoder_output = self.decoder(*features)
50
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
1737
1738 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1748
1749 result = None
/usr/local/lib/python3.11/dist-packages/segmentation_models_pytorch_3d/encoders/timm_universal.py in forward(self, x)
28
29 def forward(self, x):
---> 30 features = self.model(x)
31 features = [
32 x,
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
1737
1738 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1748
1749 result = None
/usr/local/lib/python3.11/dist-packages/timm_3d/models/_features.py in forward(self, x)
280
281 def forward(self, x) -> (List[torch.Tensor]):
--> 282 return list(self._collect(x).values())
283
284
/usr/local/lib/python3.11/dist-packages/timm_3d/models/_features.py in _collect(self, x)
234 x = module(x) if first_or_last_module else checkpoint(module, x)
235 else:
--> 236 x = module(x)
237
238 if name in self.return_layers:
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
1737
1738 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1748
1749 result = None
/usr/local/lib/python3.11/dist-packages/timm_3d/models/maxxvit.py in forward(self, x)
1097 x = checkpoint_seq(self.blocks, x)
1098 else:
-> 1099 x = self.blocks(x)
1100 return x
1101
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
1737
1738 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1748
1749 result = None
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/container.py in forward(self, input)
248 def forward(self, input):
249 for module in self:
--> 250 input = module(input)
251 return input
252
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
1737
1738 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1748
1749 result = None
/usr/local/lib/python3.11/dist-packages/timm_3d/models/maxxvit.py in forward(self, x)
987 x = x.permute(0, 2, 3, 4, 1) # to NHWDC (channels-last)
988 if self.attn_block is not None:
--> 989 x = self.attn_block(x)
990 x = self.attn_grid(x)
991 if not self.nchwd_attn:
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
1737
1738 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1748
1749 result = None
/usr/local/lib/python3.11/dist-packages/timm_3d/models/maxxvit.py in forward(self, x)
765 tmp = self.norm1(x)
766 # print("!K!", tmp.shape)
--> 767 tmp2 = self._partition_attn(tmp)
768 # print("!L!", tmp2.shape)
769 x = x + self.drop_path1(self.ls1(tmp2))
/usr/local/lib/python3.11/dist-packages/timm_3d/models/maxxvit.py in _partition_attn(self, x)
746 if self.partition_block:
747 # print('W part', img_size)
--> 748 partitioned = window_partition(x, self.partition_size)
749 # print(partitioned.shape)
750 else:
/usr/local/lib/python3.11/dist-packages/timm_3d/models/maxxvit.py in window_partition(x, window_size)
649 def window_partition(x, window_size: List[int]):
650 B, H, W, D, C = x.shape
--> 651 _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
652 _assert(W % window_size[1] == 0, f'height ({W}) must be divisible by window ({window_size[1]})')
653 _assert(D % window_size[2] == 0, f'height ({D}) must be divisible by window ({window_size[2]})')
/usr/local/lib/python3.11/dist-packages/torch/__init__.py in _assert(condition, message)
2038 _assert, (condition,), condition, message
2039 )
-> 2040 assert condition, message
2041
2042
AssertionError: height (3) must be divisible by window (2)
Metadata
Metadata
Assignees
Labels
No labels