-
Notifications
You must be signed in to change notification settings - Fork 9
Open
Description
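Creating a model with the `timm-skresnet34` encoder fails with `TypeError: ResNet.__init__() got an unexpected keyword argument 'strides'`. This affects most of the network architectures, not just `Unet`.

The cause is visible in the traceback below:
- The model constructor always passes `strides` to `get_encoder`.
- `get_encoder` adds `strides` to the encoder parameters and calls `SkNetEncoder(**params)`.
- `SkNetEncoder.__init__` forwards its remaining `**kwargs` unchanged to timm's `ResNet.__init__`, which does not accept a `strides` argument.

Minimal reproduction: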
In [1]: import segmentation_models_pytorch_3d as smp
In [2]: model = smp.Unet(
...: encoder_name="timm-skresnet34",
...: encoder_weights=None,
...: in_channels=1,
...: classes=3,
...: activation=None,
...: ),
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[2], line 1
----> 1 model = smp.Unet(
2 encoder_name="timm-skresnet34",
3 encoder_weights=None,
4 in_channels=1,
5 classes=3,
6 activation=None,
7 ),
File ~/anaconda3/envs/czii_2024/lib/python3.12/site-packages/segmentation_models_pytorch_3d/decoders/unet/model.py:72, in Unet.__init__(self, encoder_name, encoder_depth, encoder_weights, decoder_use_batchnorm, decoder_channels, decoder_attention_type, in_channels, classes, activation, aux_params, strides)
56 def __init__(
57 self,
58 encoder_name: str = "resnet34",
(...)
68 strides=((2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2))
69 ):
70 super().__init__()
---> 72 self.encoder = get_encoder(
73 encoder_name,
74 in_channels=in_channels,
75 depth=encoder_depth,
76 weights=encoder_weights,
77 strides=strides,
78 )
80 self.decoder = UnetDecoder(
81 encoder_channels=self.encoder.out_channels,
82 decoder_channels=decoder_channels,
(...)
87 strides=strides,
88 )
90 self.segmentation_head = SegmentationHead(
91 in_channels=decoder_channels[-1],
92 out_channels=classes,
93 activation=activation,
94 kernel_size=3,
95 )
File ~/anaconda3/envs/czii_2024/lib/python3.12/site-packages/segmentation_models_pytorch_3d/encoders/__init__.py:73, in get_encoder(name, in_channels, depth, weights, output_stride, strides, **kwargs)
71 params.update(depth=depth)
72 params.update(strides=strides)
---> 73 encoder = Encoder(**params)
75 if weights is not None:
76 try:
File ~/anaconda3/envs/czii_2024/lib/python3.12/site-packages/segmentation_models_pytorch_3d/encoders/timm_sknet.py:9, in SkNetEncoder.__init__(self, out_channels, depth, **kwargs)
8 def __init__(self, out_channels, depth=5, **kwargs):
----> 9 super().__init__(**kwargs)
10 self._depth = depth
11 self._out_channels = out_channels
TypeError: ResNet.__init__() got an unexpected keyword argument 'strides'
Metadata
Metadata
Assignees
Labels
No labels