From 363a361cec50c703a009257d9da121cbbedf01a8 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Sun, 8 Dec 2024 05:07:10 +0800 Subject: [PATCH 01/14] Update timm_universal.py --- .../encoders/timm_universal.py | 140 ++++++++++++++++-- 1 file changed, 131 insertions(+), 9 deletions(-) diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py index eb008221..6f87fe05 100644 --- a/segmentation_models_pytorch/encoders/timm_universal.py +++ b/segmentation_models_pytorch/encoders/timm_universal.py @@ -1,10 +1,60 @@ +""" +TimmUniversalEncoder provides a unified feature extraction interface built on the +`timm` library, supporting various backbone architectures, including traditional +CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy +(e.g., Swin Transformer, ConvNeXt). + +This encoder produces standardized multi-level feature maps, facilitating integration +with semantic segmentation tasks. It allows configuring the number of feature extraction +stages (`depth`) and adjusting `output_stride` when supported. + +Key Features: +- Flexible model selection through `timm.create_model`. +- A unified interface that outputs consistent, multi-level features even if the + underlying model differs in its feature hierarchy. +- Automatic alignment: If a model lacks certain early-stage features (for example, + modern architectures that start from a 1/4 scale rather than 1/2 scale), the encoder + inserts dummy features to maintain consistency with traditional CNN structures. +- Easy access to channel information: Use the `out_channels` property to retrieve + the number of channels at each feature stage. + +Feature Scale Differences: +- Traditional CNNs (e.g., ResNet) typically provide features at 1/2, 1/4, 1/8, 1/16, + and 1/32 scales. +- Transformer-style or next-generation models (e.g., Swin Transformer, ConvNeXt) often + start from the 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2 scale + feature. TimmUniversalEncoder compensates for this omission to ensure a unified + multi-stage output. + +Notes: +- Not all models support modifying `output_stride` (especially transformer-based or + transformer-like models). +- Certain models (e.g., TResNet, DLA) require special handling to ensure correct + feature indexing. +- Most `timm` models output features in (B, C, H, W) format. However, some + (e.g., MambaOut and certain Swin/SwinV2 variants) use (B, H, W, C) format, which is + currently unsupported. +""" + from typing import Any import timm +import torch import torch.nn as nn class TimmUniversalEncoder(nn.Module): + """ + A universal encoder built on the `timm` library, designed to adapt to a wide variety of + model architectures, including both traditional CNNs and those that follow a + transformer-like hierarchy. + + Features: + - Supports flexible depth and output stride for feature extraction. + - Automatically adjusts to input/output channel structures based on the model type. + - Compatible with both convolutional and transformer-like encoders. + """ + def __init__( self, name: str, @@ -14,7 +64,19 @@ def __init__( output_stride: int = 32, **kwargs: dict[str, Any], ): + """ + Initialize the encoder. + + Args: + name (str): Name of the model to be loaded from the `timm` library. + pretrained (bool): If True, loads pretrained weights. + in_channels (int): Number of input channels (default: 3 for RGB). + depth (int): Number of feature extraction stages (default: 5). 
+ output_stride (int): Desired output stride (default: 32). + **kwargs: Additional keyword arguments for `timm.create_model`. + """ super().__init__() + common_kwargs = dict( in_chans=in_channels, features_only=True, @@ -23,30 +85,90 @@ def __init__( out_indices=tuple(range(depth)), ) - # not all models support output stride argument, drop it by default if output_stride == 32: common_kwargs.pop("output_stride") - self.model = timm.create_model( - name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs) - ) + # Load a preliminary model to determine its feature hierarchy structure. + self.model = timm.create_model(name, features_only=True) + + # Determine if this model uses a transformer-like hierarchy (i.e., starting at 1/4 scale) + # rather than a traditional CNN hierarchy (starting at 1/2 scale). + if len(self.model.feature_info.channels()) == 5: + # This indicates a traditional hierarchy: (1/2, 1/4, 1/8, 1/16, 1/32) + self._is_transformer_style = False + else: + # This indicates a transformer-like hierarchy: (1/4, 1/8, 1/16, 1/32) + self._is_transformer_style = True + + if self._is_transformer_style: + if "tresnet" in name: + # 'tresnet' models start feature extraction at stage 1, + # so out_indices=(1, 2, 3, 4) for depth=5. + common_kwargs["out_indices"] = tuple(range(1, depth)) + else: + # Most transformer-like models use out_indices=(0, 1, 2, 3) for depth=5. + common_kwargs["out_indices"] = tuple(range(depth - 1)) + + self.model = timm.create_model( + name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs) + ) + # Add a dummy output channel (0) to align with traditional encoder structures. + self._out_channels = ( + [in_channels] + [0] + self.model.feature_info.channels() + ) + else: + if "dla" in name: + # For 'dla' models, out_indices starts at 0 and matches the input size. + kwargs["out_indices"] = tuple(range(1, depth + 1)) + + self.model = timm.create_model( + name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs) + ) + self._out_channels = [in_channels] + self.model.feature_info.channels() self._in_channels = in_channels - self._out_channels = [in_channels] + self.model.feature_info.channels() self._depth = depth self._output_stride = output_stride - def forward(self, x): + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + """ + Pass the input through the encoder and return extracted features. + + Args: + x (torch.Tensor): Input tensor of shape (B, C, H, W). + + Returns: + List[torch.Tensor]: A list of feature maps extracted at various scales. + """ features = self.model(x) - features = [x] + features + + if self._is_transformer_style: + # Models using a transformer-like hierarchy may not generate + # all expected feature maps. Insert a dummy feature map to ensure + # compatibility with decoders expecting a 5-level pyramid. + B, _, H, W = x.shape + dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device) + features = [x] + [dummy] + features + else: + features = [x] + features + return features @property - def out_channels(self): + def out_channels(self) -> list[int]: + """ + Returns: + List[int]: A list of output channels for each stage of the encoder, + including the input channels at the first stage. + """ return self._out_channels @property - def output_stride(self): + def output_stride(self) -> int: + """ + Returns: + int: The effective output stride of the encoder, considering the depth. 
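+
+        Example:
+            A worked case (the backbone name is illustrative; no weights are
+            downloaded):
+
+            >>> encoder = TimmUniversalEncoder("resnet18", pretrained=False, depth=3)
+            >>> encoder.output_stride  # min(32, 2**3)
+            8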
+ """ return min(self._output_stride, 2**self._depth) From f07e10782cc6746a551bd9e171eab5bba0f44ae3 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Sun, 8 Dec 2024 05:10:42 +0800 Subject: [PATCH 02/14] Fix ruff style and typing --- .../encoders/timm_universal.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py index 6f87fe05..e3b4ac17 100644 --- a/segmentation_models_pytorch/encoders/timm_universal.py +++ b/segmentation_models_pytorch/encoders/timm_universal.py @@ -1,7 +1,7 @@ """ TimmUniversalEncoder provides a unified feature extraction interface built on the `timm` library, supporting various backbone architectures, including traditional -CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy +CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy (e.g., Swin Transformer, ConvNeXt). This encoder produces standardized multi-level feature maps, facilitating integration @@ -22,16 +22,16 @@ - Traditional CNNs (e.g., ResNet) typically provide features at 1/2, 1/4, 1/8, 1/16, and 1/32 scales. - Transformer-style or next-generation models (e.g., Swin Transformer, ConvNeXt) often - start from the 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2 scale + start from the 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2 scale feature. TimmUniversalEncoder compensates for this omission to ensure a unified multi-stage output. Notes: -- Not all models support modifying `output_stride` (especially transformer-based or +- Not all models support modifying `output_stride` (especially transformer-based or transformer-like models). - Certain models (e.g., TResNet, DLA) require special handling to ensure correct feature indexing. -- Most `timm` models output features in (B, C, H, W) format. However, some +- Most `timm` models output features in (B, C, H, W) format. However, some (e.g., MambaOut and certain Swin/SwinV2 variants) use (B, H, W, C) format, which is currently unsupported. """ @@ -46,7 +46,7 @@ class TimmUniversalEncoder(nn.Module): """ A universal encoder built on the `timm` library, designed to adapt to a wide variety of - model architectures, including both traditional CNNs and those that follow a + model architectures, including both traditional CNNs and those that follow a transformer-like hierarchy. Features: @@ -94,10 +94,8 @@ def __init__( # Determine if this model uses a transformer-like hierarchy (i.e., starting at 1/4 scale) # rather than a traditional CNN hierarchy (starting at 1/2 scale). if len(self.model.feature_info.channels()) == 5: - # This indicates a traditional hierarchy: (1/2, 1/4, 1/8, 1/16, 1/32) self._is_transformer_style = False else: - # This indicates a transformer-like hierarchy: (1/4, 1/8, 1/16, 1/32) self._is_transformer_style = True if self._is_transformer_style: @@ -138,7 +136,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]: x (torch.Tensor): Input tensor of shape (B, C, H, W). Returns: - List[torch.Tensor]: A list of feature maps extracted at various scales. + list[torch.Tensor]: A list of feature maps extracted at various scales. 
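+
+        Example:
+            A minimal sketch (the backbone name is illustrative; weights are
+            not downloaded):
+
+            >>> encoder = TimmUniversalEncoder("resnet18", pretrained=False)
+            >>> feats = encoder(torch.rand(1, 3, 224, 224))
+            >>> [f.shape[1] for f in feats] == encoder.out_channels
+            True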
""" features = self.model(x) @@ -158,7 +156,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]: def out_channels(self) -> list[int]: """ Returns: - List[int]: A list of output channels for each stage of the encoder, + list[int]: A list of output channels for each stage of the encoder, including the input channels at the first stage. """ return self._out_channels From ea725dd5914a0cddc0a35fa30e4a7cb40aa08e4c Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Sun, 8 Dec 2024 06:00:22 +0800 Subject: [PATCH 03/14] Update encoders_timm.rst --- docs/encoders_timm.rst | 700 ++++++++++++++++++++++++++++++++++------- 1 file changed, 585 insertions(+), 115 deletions(-) diff --git a/docs/encoders_timm.rst b/docs/encoders_timm.rst index 26a18a64..903f99cd 100644 --- a/docs/encoders_timm.rst +++ b/docs/encoders_timm.rst @@ -1,34 +1,29 @@ 🎯 Timm Encoders -~~~~~~~~~~~~~~~~ +================ Pytorch Image Models (a.k.a. timm) has a lot of pretrained models and interface which allows using these models as encoders in smp, however, not all models are supported - not all transformer models have ``features_only`` functionality implemented that is required for encoder - some models have inappropriate strides + - some models (such as certain Swin/SwinV2 variants or those outputting in (B, H, W, C) format) are currently not supported Below is a table of suitable encoders (for DeepLabV3, DeepLabV3+, and PAN dilation support is needed also) -Total number of encoders: 549 +Total number of encoders: 761 (579+182) .. note:: To use following encoders you have to add prefix ``tu-``, e.g. ``tu-adv_inception_v3`` +Tranditional-Style Models +~~~~~~~~~~~~~~~~~~~~~~~~~ + +These models typically produce feature maps at the following downsampling scales relative to the input resolution: 1/2, 1/4, 1/8, 1/16, and 1/32 +----------------------------------+------------------+ | Encoder name | Support dilation | +==================================+==================+ -| SelecSls42 | | -+----------------------------------+------------------+ -| SelecSls42b | | -+----------------------------------+------------------+ -| SelecSls60 | | -+----------------------------------+------------------+ -| SelecSls60b | | -+----------------------------------+------------------+ -| SelecSls84 | | -+----------------------------------+------------------+ | bat_resnext26ts | ✅ | +----------------------------------+------------------+ | botnet26t_256 | ✅ | @@ -125,14 +120,6 @@ Total number of encoders: 549 +----------------------------------+------------------+ | densenetblur121d | | +----------------------------------+------------------+ -| dla102 | | -+----------------------------------+------------------+ -| dla102x | | -+----------------------------------+------------------+ -| dla102x2 | | -+----------------------------------+------------------+ -| dla169 | | -+----------------------------------+------------------+ | dla34 | | +----------------------------------+------------------+ | dla46_c | | @@ -149,6 +136,14 @@ Total number of encoders: 549 +----------------------------------+------------------+ | dla60x_c | | +----------------------------------+------------------+ +| dla102 | | ++----------------------------------+------------------+ +| dla102x | | ++----------------------------------+------------------+ +| dla102x2 | | ++----------------------------------+------------------+ +| dla169 | | ++----------------------------------+------------------+ | dm_nfnet_f0 | ✅ | 
+----------------------------------+------------------+ | dm_nfnet_f1 | ✅ | @@ -163,10 +158,6 @@ Total number of encoders: 549 +----------------------------------+------------------+ | dm_nfnet_f6 | ✅ | +----------------------------------+------------------+ -| dpn107 | | -+----------------------------------+------------------+ -| dpn131 | | -+----------------------------------+------------------+ | dpn48b | | +----------------------------------+------------------+ | dpn68 | | @@ -177,6 +168,10 @@ Total number of encoders: 549 +----------------------------------+------------------+ | dpn98 | | +----------------------------------+------------------+ +| dpn107 | | ++----------------------------------+------------------+ +| dpn131 | | ++----------------------------------+------------------+ | eca_botnext26ts_256 | ✅ | +----------------------------------+------------------+ | eca_halonext26ts | ✅ | @@ -233,8 +228,6 @@ Total number of encoders: 549 +----------------------------------+------------------+ | efficientnet_b2_pruned | ✅ | +----------------------------------+------------------+ -| efficientnet_b2a | ✅ | -+----------------------------------+------------------+ | efficientnet_b3 | ✅ | +----------------------------------+------------------+ | efficientnet_b3_g8_gn | ✅ | @@ -243,8 +236,6 @@ Total number of encoders: 549 +----------------------------------+------------------+ | efficientnet_b3_pruned | ✅ | +----------------------------------+------------------+ -| efficientnet_b3a | ✅ | -+----------------------------------+------------------+ | efficientnet_b4 | ✅ | +----------------------------------+------------------+ | efficientnet_b5 | ✅ | @@ -255,6 +246,8 @@ Total number of encoders: 549 +----------------------------------+------------------+ | efficientnet_b8 | ✅ | +----------------------------------+------------------+ +| efficientnet_blur_b0 | ✅ | ++----------------------------------+------------------+ | efficientnet_cc_b0_4e | ✅ | +----------------------------------+------------------+ | efficientnet_cc_b0_8e | ✅ | @@ -341,6 +334,12 @@ Total number of encoders: 549 +----------------------------------+------------------+ | ghostnet_130 | | +----------------------------------+------------------+ +| ghostnetv2_050 | | ++----------------------------------+------------------+ +| ghostnetv2_100 | | ++----------------------------------+------------------+ +| ghostnetv2_130 | | ++----------------------------------+------------------+ | halo2botnet50ts_256 | ✅ | +----------------------------------+------------------+ | halonet26t | ✅ | @@ -385,12 +384,6 @@ Total number of encoders: 549 +----------------------------------+------------------+ | hrnet_w64 | | +----------------------------------+------------------+ -| inception_resnet_v2 | | -+----------------------------------+------------------+ -| inception_v3 | | -+----------------------------------+------------------+ -| inception_v4 | | -+----------------------------------+------------------+ | lambda_resnet26rpt_256 | ✅ | +----------------------------------+------------------+ | lambda_resnet26t | ✅ | @@ -411,23 +404,21 @@ Total number of encoders: 549 +----------------------------------+------------------+ | legacy_senet154 | | +----------------------------------+------------------+ -| legacy_seresnet101 | | -+----------------------------------+------------------+ -| legacy_seresnet152 | | -+----------------------------------+------------------+ | legacy_seresnet18 | | +----------------------------------+------------------+ | 
legacy_seresnet34 | | +----------------------------------+------------------+ | legacy_seresnet50 | | +----------------------------------+------------------+ -| legacy_seresnext101_32x4d | | +| legacy_seresnet101 | | ++----------------------------------+------------------+ +| legacy_seresnet152 | | +----------------------------------+------------------+ | legacy_seresnext26_32x4d | | +----------------------------------+------------------+ | legacy_seresnext50_32x4d | | +----------------------------------+------------------+ -| legacy_xception | | +| legacy_seresnext101_32x4d | | +----------------------------------+------------------+ | maxvit_base_tf_224 | | +----------------------------------+------------------+ @@ -515,11 +506,23 @@ Total number of encoders: 549 +----------------------------------+------------------+ | mnasnet_140 | ✅ | +----------------------------------+------------------+ -| mnasnet_a1 | ✅ | +| mnasnet_small | ✅ | +----------------------------------+------------------+ -| mnasnet_b1 | ✅ | +| mobilenet_edgetpu_100 | ✅ | +----------------------------------+------------------+ -| mnasnet_small | ✅ | +| mobilenet_edgetpu_v2_l | ✅ | ++----------------------------------+------------------+ +| mobilenet_edgetpu_v2_m | ✅ | ++----------------------------------+------------------+ +| mobilenet_edgetpu_v2_s | ✅ | ++----------------------------------+------------------+ +| mobilenet_edgetpu_v2_xs | ✅ | ++----------------------------------+------------------+ +| mobilenetv1_100 | ✅ | ++----------------------------------+------------------+ +| mobilenetv1_100h | ✅ | ++----------------------------------+------------------+ +| mobilenetv1_125 | ✅ | +----------------------------------+------------------+ | mobilenetv2_035 | ✅ | +----------------------------------+------------------+ @@ -539,6 +542,8 @@ Total number of encoders: 549 +----------------------------------+------------------+ | mobilenetv3_large_100 | ✅ | +----------------------------------+------------------+ +| mobilenetv3_large_150d | ✅ | ++----------------------------------+------------------+ | mobilenetv3_rw | ✅ | +----------------------------------+------------------+ | mobilenetv3_small_050 | ✅ | @@ -547,6 +552,40 @@ Total number of encoders: 549 +----------------------------------+------------------+ | mobilenetv3_small_100 | ✅ | +----------------------------------+------------------+ +| mobilenetv4_conv_aa_large | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_conv_aa_medium | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_conv_blur_medium | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_conv_large | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_conv_medium | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_conv_small | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_conv_small_035 | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_conv_small_050 | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_hybrid_large | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_hybrid_large_075 | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_hybrid_medium | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_hybrid_medium_075 | ✅ | ++----------------------------------+------------------+ +| mobileone_s0 | ✅ | 
++----------------------------------+------------------+ +| mobileone_s1 | ✅ | ++----------------------------------+------------------+ +| mobileone_s2 | ✅ | ++----------------------------------+------------------+ +| mobileone_s3 | ✅ | ++----------------------------------+------------------+ +| mobileone_s4 | ✅ | ++----------------------------------+------------------+ | mobilevit_s | ✅ | +----------------------------------+------------------+ | mobilevit_xs | ✅ | @@ -567,14 +606,12 @@ Total number of encoders: 549 +----------------------------------+------------------+ | mobilevitv2_200 | ✅ | +----------------------------------+------------------+ -| nasnetalarge | | -+----------------------------------+------------------+ -| nf_ecaresnet101 | ✅ | -+----------------------------------+------------------+ | nf_ecaresnet26 | ✅ | +----------------------------------+------------------+ | nf_ecaresnet50 | ✅ | +----------------------------------+------------------+ +| nf_ecaresnet101 | ✅ | ++----------------------------------+------------------+ | nf_regnet_b0 | ✅ | +----------------------------------+------------------+ | nf_regnet_b1 | ✅ | @@ -587,18 +624,18 @@ Total number of encoders: 549 +----------------------------------+------------------+ | nf_regnet_b5 | ✅ | +----------------------------------+------------------+ -| nf_resnet101 | ✅ | -+----------------------------------+------------------+ | nf_resnet26 | ✅ | +----------------------------------+------------------+ | nf_resnet50 | ✅ | +----------------------------------+------------------+ -| nf_seresnet101 | ✅ | +| nf_resnet101 | ✅ | +----------------------------------+------------------+ | nf_seresnet26 | ✅ | +----------------------------------+------------------+ | nf_seresnet50 | ✅ | +----------------------------------+------------------+ +| nf_seresnet101 | ✅ | ++----------------------------------+------------------+ | nfnet_f0 | ✅ | +----------------------------------+------------------+ | nfnet_f1 | ✅ | @@ -617,8 +654,6 @@ Total number of encoders: 549 +----------------------------------+------------------+ | nfnet_l0 | ✅ | +----------------------------------+------------------+ -| pnasnet5large | | -+----------------------------------+------------------+ | regnetv_040 | ✅ | +----------------------------------+------------------+ | regnetv_064 | ✅ | @@ -675,10 +710,10 @@ Total number of encoders: 549 +----------------------------------+------------------+ | regnety_120 | ✅ | +----------------------------------+------------------+ -| regnety_1280 | ✅ | -+----------------------------------+------------------+ | regnety_160 | ✅ | +----------------------------------+------------------+ +| regnety_1280 | ✅ | ++----------------------------------+------------------+ | regnety_2560 | ✅ | +----------------------------------+------------------+ | regnety_320 | ✅ | @@ -707,6 +742,26 @@ Total number of encoders: 549 +----------------------------------+------------------+ | regnetz_e8 | ✅ | +----------------------------------+------------------+ +| repghostnet_050 | | ++----------------------------------+------------------+ +| repghostnet_058 | | ++----------------------------------+------------------+ +| repghostnet_080 | | ++----------------------------------+------------------+ +| repghostnet_100 | | ++----------------------------------+------------------+ +| repghostnet_111 | | ++----------------------------------+------------------+ +| repghostnet_130 | | ++----------------------------------+------------------+ +| repghostnet_150 | | 
++----------------------------------+------------------+ +| repghostnet_200 | | ++----------------------------------+------------------+ +| repvgg_a0 | ✅ | ++----------------------------------+------------------+ +| repvgg_a1 | ✅ | ++----------------------------------+------------------+ | repvgg_a2 | ✅ | +----------------------------------+------------------+ | repvgg_b0 | ✅ | @@ -723,9 +778,7 @@ Total number of encoders: 549 +----------------------------------+------------------+ | repvgg_b3g4 | ✅ | +----------------------------------+------------------+ -| res2net101_26w_4s | ✅ | -+----------------------------------+------------------+ -| res2net101d | ✅ | +| repvgg_d2se | ✅ | +----------------------------------+------------------+ | res2net50_14w_8s | ✅ | +----------------------------------+------------------+ @@ -739,15 +792,13 @@ Total number of encoders: 549 +----------------------------------+------------------+ | res2net50d | ✅ | +----------------------------------+------------------+ -| res2next50 | ✅ | -+----------------------------------+------------------+ -| resnest101e | ✅ | +| res2net101_26w_4s | ✅ | +----------------------------------+------------------+ -| resnest14d | ✅ | +| res2net101d | ✅ | +----------------------------------+------------------+ -| resnest200e | ✅ | +| res2next50 | ✅ | +----------------------------------+------------------+ -| resnest269e | ✅ | +| resnest14d | ✅ | +----------------------------------+------------------+ | resnest26d | ✅ | +----------------------------------+------------------+ @@ -757,34 +808,20 @@ Total number of encoders: 549 +----------------------------------+------------------+ | resnest50d_4s2x40d | ✅ | +----------------------------------+------------------+ -| resnet101 | ✅ | -+----------------------------------+------------------+ -| resnet101c | ✅ | +| resnest101e | ✅ | +----------------------------------+------------------+ -| resnet101d | ✅ | +| resnest200e | ✅ | +----------------------------------+------------------+ -| resnet101s | ✅ | +| resnest269e | ✅ | +----------------------------------+------------------+ | resnet10t | ✅ | +----------------------------------+------------------+ | resnet14t | ✅ | +----------------------------------+------------------+ -| resnet152 | ✅ | -+----------------------------------+------------------+ -| resnet152c | ✅ | -+----------------------------------+------------------+ -| resnet152d | ✅ | -+----------------------------------+------------------+ -| resnet152s | ✅ | -+----------------------------------+------------------+ | resnet18 | ✅ | +----------------------------------+------------------+ | resnet18d | ✅ | +----------------------------------+------------------+ -| resnet200 | ✅ | -+----------------------------------+------------------+ -| resnet200d | ✅ | -+----------------------------------+------------------+ | resnet26 | ✅ | +----------------------------------+------------------+ | resnet26d | ✅ | @@ -803,6 +840,12 @@ Total number of encoders: 549 +----------------------------------+------------------+ | resnet50_gn | ✅ | +----------------------------------+------------------+ +| resnet50_clip | ✅ | ++----------------------------------+------------------+ +| resnet50_clip_gap | ✅ | ++----------------------------------+------------------+ +| resnet50_mlp | ✅ | ++----------------------------------+------------------+ | resnet50c | ✅ | +----------------------------------+------------------+ | resnet50d | ✅ | @@ -811,11 +854,45 @@ Total number of encoders: 549 
+----------------------------------+------------------+ | resnet50t | ✅ | +----------------------------------+------------------+ +| resnet50x4_clip | ✅ | ++----------------------------------+------------------+ +| resnet50x4_clip_gap | ✅ | ++----------------------------------+------------------+ +| resnet50x16_clip | ✅ | ++----------------------------------+------------------+ +| resnet50x16_clip_gap | ✅ | ++----------------------------------+------------------+ +| resnet50x64_clip | ✅ | ++----------------------------------+------------------+ +| resnet50x64_clip_gap | ✅ | ++----------------------------------+------------------+ | resnet51q | ✅ | +----------------------------------+------------------+ | resnet61q | ✅ | +----------------------------------+------------------+ -| resnetaa101d | ✅ | +| resnet101 | ✅ | ++----------------------------------+------------------+ +| resnet101_clip | ✅ | ++----------------------------------+------------------+ +| resnet101_clip_gap | ✅ | ++----------------------------------+------------------+ +| resnet101c | ✅ | ++----------------------------------+------------------+ +| resnet101d | ✅ | ++----------------------------------+------------------+ +| resnet101s | ✅ | ++----------------------------------+------------------+ +| resnet152 | ✅ | ++----------------------------------+------------------+ +| resnet152c | ✅ | ++----------------------------------+------------------+ +| resnet152d | ✅ | ++----------------------------------+------------------+ +| resnet152s | ✅ | ++----------------------------------+------------------+ +| resnet200 | ✅ | ++----------------------------------+------------------+ +| resnet200d | ✅ | +----------------------------------+------------------+ | resnetaa34d | ✅ | +----------------------------------+------------------+ @@ -823,7 +900,7 @@ Total number of encoders: 549 +----------------------------------+------------------+ | resnetaa50d | ✅ | +----------------------------------+------------------+ -| resnetblur101d | ✅ | +| resnetaa101d | ✅ | +----------------------------------+------------------+ | resnetblur18 | ✅ | +----------------------------------+------------------+ @@ -831,6 +908,10 @@ Total number of encoders: 549 +----------------------------------+------------------+ | resnetblur50d | ✅ | +----------------------------------+------------------+ +| resnetblur101d | ✅ | ++----------------------------------+------------------+ +| resnetrs50 | ✅ | ++----------------------------------+------------------+ | resnetrs101 | ✅ | +----------------------------------+------------------+ | resnetrs152 | ✅ | @@ -843,23 +924,13 @@ Total number of encoders: 549 +----------------------------------+------------------+ | resnetrs420 | ✅ | +----------------------------------+------------------+ -| resnetrs50 | ✅ | +| resnetv2_18 | ✅ | +----------------------------------+------------------+ -| resnetv2_101 | ✅ | +| resnetv2_18d | ✅ | +----------------------------------+------------------+ -| resnetv2_101d | ✅ | +| resnetv2_34 | ✅ | +----------------------------------+------------------+ -| resnetv2_101x1_bit | ✅ | -+----------------------------------+------------------+ -| resnetv2_101x3_bit | ✅ | -+----------------------------------+------------------+ -| resnetv2_152 | ✅ | -+----------------------------------+------------------+ -| resnetv2_152d | ✅ | -+----------------------------------+------------------+ -| resnetv2_152x2_bit | ✅ | -+----------------------------------+------------------+ -| resnetv2_152x4_bit | ✅ | +| resnetv2_34d 
| ✅ | +----------------------------------+------------------+ | resnetv2_50 | ✅ | +----------------------------------+------------------+ @@ -877,15 +948,21 @@ Total number of encoders: 549 +----------------------------------+------------------+ | resnetv2_50x3_bit | ✅ | +----------------------------------+------------------+ -| resnext101_32x16d | ✅ | +| resnetv2_101 | ✅ | +----------------------------------+------------------+ -| resnext101_32x32d | ✅ | +| resnetv2_101d | ✅ | +----------------------------------+------------------+ -| resnext101_32x4d | ✅ | +| resnetv2_101x1_bit | ✅ | +----------------------------------+------------------+ -| resnext101_32x8d | ✅ | +| resnetv2_101x3_bit | ✅ | +----------------------------------+------------------+ -| resnext101_64x4d | ✅ | +| resnetv2_152 | ✅ | ++----------------------------------+------------------+ +| resnetv2_152d | ✅ | ++----------------------------------+------------------+ +| resnetv2_152x2_bit | ✅ | ++----------------------------------+------------------+ +| resnetv2_152x4_bit | ✅ | +----------------------------------+------------------+ | resnext26ts | ✅ | +----------------------------------+------------------+ @@ -893,6 +970,16 @@ Total number of encoders: 549 +----------------------------------+------------------+ | resnext50d_32x4d | ✅ | +----------------------------------+------------------+ +| resnext101_32x4d | ✅ | ++----------------------------------+------------------+ +| resnext101_32x8d | ✅ | ++----------------------------------+------------------+ +| resnext101_32x16d | ✅ | ++----------------------------------+------------------+ +| resnext101_32x32d | ✅ | ++----------------------------------+------------------+ +| resnext101_64x4d | ✅ | ++----------------------------------+------------------+ | rexnet_100 | ✅ | +----------------------------------+------------------+ | rexnet_130 | ✅ | @@ -917,6 +1004,16 @@ Total number of encoders: 549 +----------------------------------+------------------+ | sehalonet33ts | ✅ | +----------------------------------+------------------+ +| selecsls42 | | ++----------------------------------+------------------+ +| selecsls42b | | ++----------------------------------+------------------+ +| selecsls60 | | ++----------------------------------+------------------+ +| selecsls60b | | ++----------------------------------+------------------+ +| selecsls84 | | ++----------------------------------+------------------+ | semnasnet_050 | ✅ | +----------------------------------+------------------+ | semnasnet_075 | ✅ | @@ -927,18 +1024,8 @@ Total number of encoders: 549 +----------------------------------+------------------+ | senet154 | ✅ | +----------------------------------+------------------+ -| seresnet101 | ✅ | -+----------------------------------+------------------+ -| seresnet152 | ✅ | -+----------------------------------+------------------+ -| seresnet152d | ✅ | -+----------------------------------+------------------+ | seresnet18 | ✅ | +----------------------------------+------------------+ -| seresnet200d | ✅ | -+----------------------------------+------------------+ -| seresnet269d | ✅ | -+----------------------------------+------------------+ | seresnet33ts | ✅ | +----------------------------------+------------------+ | seresnet34 | ✅ | @@ -947,6 +1034,16 @@ Total number of encoders: 549 +----------------------------------+------------------+ | seresnet50t | ✅ | +----------------------------------+------------------+ +| seresnet101 | ✅ | 
++----------------------------------+------------------+ +| seresnet152 | ✅ | ++----------------------------------+------------------+ +| seresnet152d | ✅ | ++----------------------------------+------------------+ +| seresnet200d | ✅ | ++----------------------------------+------------------+ +| seresnet269d | ✅ | ++----------------------------------+------------------+ | seresnetaa50d | ✅ | +----------------------------------+------------------+ | seresnext101_32x4d | ✅ | @@ -961,14 +1058,14 @@ Total number of encoders: 549 +----------------------------------+------------------+ | seresnext26t_32x4d | ✅ | +----------------------------------+------------------+ -| seresnext26tn_32x4d | ✅ | -+----------------------------------+------------------+ | seresnext26ts | ✅ | +----------------------------------+------------------+ | seresnext50_32x4d | ✅ | +----------------------------------+------------------+ | seresnextaa101d_32x8d | ✅ | +----------------------------------+------------------+ +| seresnextaa201d_32x8d | ✅ | ++----------------------------------+------------------+ | skresnet18 | ✅ | +----------------------------------+------------------+ | skresnet34 | ✅ | @@ -1071,10 +1168,10 @@ Total number of encoders: 549 +----------------------------------+------------------+ | vovnet57a | | +----------------------------------+------------------+ -| wide_resnet101_2 | ✅ | -+----------------------------------+------------------+ | wide_resnet50_2 | ✅ | +----------------------------------+------------------+ +| wide_resnet101_2 | ✅ | ++----------------------------------+------------------+ | xception41 | ✅ | +----------------------------------+------------------+ | xception41p | ✅ | @@ -1086,3 +1183,376 @@ Total number of encoders: 549 | xception71 | ✅ | +----------------------------------+------------------+ +Transformer-style +~~~~~~~~~~~~~~~~~ + +Transformer-style models (e.g., Swin Transformer, ConvNeXt) typically produce feature maps starting at a 1/4 scale, followed by 1/8, 1/16, and 1/32 scales + ++----------------------------------+------------------+ +| Encoder name | Support dilation | ++==================================+==================+ +| caformer_b36 | | ++----------------------------------+------------------+ +| caformer_m36 | | ++----------------------------------+------------------+ +| caformer_s18 | | ++----------------------------------+------------------+ +| caformer_s36 | | ++----------------------------------+------------------+ +| convformer_b36 | | ++----------------------------------+------------------+ +| convformer_m36 | | ++----------------------------------+------------------+ +| convformer_s18 | | ++----------------------------------+------------------+ +| convformer_s36 | | ++----------------------------------+------------------+ +| convnext_atto | ✅ | ++----------------------------------+------------------+ +| convnext_atto_ols | ✅ | ++----------------------------------+------------------+ +| convnext_atto_rms | ✅ | ++----------------------------------+------------------+ +| convnext_base | ✅ | ++----------------------------------+------------------+ +| convnext_femto | ✅ | ++----------------------------------+------------------+ +| convnext_femto_ols | ✅ | ++----------------------------------+------------------+ +| convnext_large | ✅ | ++----------------------------------+------------------+ +| convnext_large_mlp | ✅ | ++----------------------------------+------------------+ +| convnext_nano | ✅ | ++----------------------------------+------------------+ +| 
convnext_nano_ols | ✅ | ++----------------------------------+------------------+ +| convnext_pico | ✅ | ++----------------------------------+------------------+ +| convnext_pico_ols | ✅ | ++----------------------------------+------------------+ +| convnext_small | ✅ | ++----------------------------------+------------------+ +| convnext_tiny | ✅ | ++----------------------------------+------------------+ +| convnext_tiny_hnf | ✅ | ++----------------------------------+------------------+ +| convnext_xlarge | ✅ | ++----------------------------------+------------------+ +| convnext_xxlarge | ✅ | ++----------------------------------+------------------+ +| convnext_zepto_rms | ✅ | ++----------------------------------+------------------+ +| convnext_zepto_rms_ols | ✅ | ++----------------------------------+------------------+ +| convnextv2_atto | ✅ | ++----------------------------------+------------------+ +| convnextv2_base | ✅ | ++----------------------------------+------------------+ +| convnextv2_femto | ✅ | ++----------------------------------+------------------+ +| convnextv2_huge | ✅ | ++----------------------------------+------------------+ +| convnextv2_large | ✅ | ++----------------------------------+------------------+ +| convnextv2_nano | ✅ | ++----------------------------------+------------------+ +| convnextv2_pico | ✅ | ++----------------------------------+------------------+ +| convnextv2_small | ✅ | ++----------------------------------+------------------+ +| convnextv2_tiny | ✅ | ++----------------------------------+------------------+ +| davit_base | | ++----------------------------------+------------------+ +| davit_base_fl | | ++----------------------------------+------------------+ +| davit_giant | | ++----------------------------------+------------------+ +| davit_huge | | ++----------------------------------+------------------+ +| davit_huge_fl | | ++----------------------------------+------------------+ +| davit_large | | ++----------------------------------+------------------+ +| davit_small | | ++----------------------------------+------------------+ +| davit_tiny | | ++----------------------------------+------------------+ +| edgenext_base | | ++----------------------------------+------------------+ +| edgenext_small | | ++----------------------------------+------------------+ +| edgenext_small_rw | | ++----------------------------------+------------------+ +| edgenext_x_small | | ++----------------------------------+------------------+ +| edgenext_xx_small | | ++----------------------------------+------------------+ +| efficientformer_l1 | | ++----------------------------------+------------------+ +| efficientformer_l3 | | ++----------------------------------+------------------+ +| efficientformer_l7 | | ++----------------------------------+------------------+ +| efficientformerv2_l | | ++----------------------------------+------------------+ +| efficientformerv2_s0 | | ++----------------------------------+------------------+ +| efficientformerv2_s1 | | ++----------------------------------+------------------+ +| efficientformerv2_s2 | | ++----------------------------------+------------------+ +| efficientvit_b0 | | ++----------------------------------+------------------+ +| efficientvit_b1 | | ++----------------------------------+------------------+ +| efficientvit_b2 | | ++----------------------------------+------------------+ +| efficientvit_b3 | | ++----------------------------------+------------------+ +| efficientvit_l1 | | 
++----------------------------------+------------------+ +| efficientvit_l2 | | ++----------------------------------+------------------+ +| efficientvit_l3 | | ++----------------------------------+------------------+ +| fastvit_ma36 | | ++----------------------------------+------------------+ +| fastvit_mci0 | | ++----------------------------------+------------------+ +| fastvit_mci1 | | ++----------------------------------+------------------+ +| fastvit_mci2 | | ++----------------------------------+------------------+ +| fastvit_s12 | | ++----------------------------------+------------------+ +| fastvit_sa12 | | ++----------------------------------+------------------+ +| fastvit_sa24 | | ++----------------------------------+------------------+ +| fastvit_sa36 | | ++----------------------------------+------------------+ +| fastvit_t8 | | ++----------------------------------+------------------+ +| fastvit_t12 | | ++----------------------------------+------------------+ +| focalnet_base_lrf | | ++----------------------------------+------------------+ +| focalnet_base_srf | | ++----------------------------------+------------------+ +| focalnet_huge_fl3 | | ++----------------------------------+------------------+ +| focalnet_huge_fl4 | | ++----------------------------------+------------------+ +| focalnet_large_fl3 | | ++----------------------------------+------------------+ +| focalnet_large_fl4 | | ++----------------------------------+------------------+ +| focalnet_small_lrf | | ++----------------------------------+------------------+ +| focalnet_small_srf | | ++----------------------------------+------------------+ +| focalnet_tiny_lrf | | ++----------------------------------+------------------+ +| focalnet_tiny_srf | | ++----------------------------------+------------------+ +| focalnet_xlarge_fl3 | | ++----------------------------------+------------------+ +| focalnet_xlarge_fl4 | | ++----------------------------------+------------------+ +| hgnet_base | | ++----------------------------------+------------------+ +| hgnet_small | | ++----------------------------------+------------------+ +| hgnet_tiny | | ++----------------------------------+------------------+ +| hgnetv2_b0 | | ++----------------------------------+------------------+ +| hgnetv2_b1 | | ++----------------------------------+------------------+ +| hgnetv2_b2 | | ++----------------------------------+------------------+ +| hgnetv2_b3 | | ++----------------------------------+------------------+ +| hgnetv2_b4 | | ++----------------------------------+------------------+ +| hgnetv2_b5 | | ++----------------------------------+------------------+ +| hgnetv2_b6 | | ++----------------------------------+------------------+ +| hiera_base_224 | | ++----------------------------------+------------------+ +| hiera_base_abswin_256 | | ++----------------------------------+------------------+ +| hiera_base_plus_224 | | ++----------------------------------+------------------+ +| hiera_huge_224 | | ++----------------------------------+------------------+ +| hiera_large_224 | | ++----------------------------------+------------------+ +| hiera_small_224 | | ++----------------------------------+------------------+ +| hiera_small_abswin_256 | | ++----------------------------------+------------------+ +| hiera_tiny_224 | | ++----------------------------------+------------------+ +| hieradet_small | | ++----------------------------------+------------------+ +| inception_next_base | | ++----------------------------------+------------------+ +| 
inception_next_small | | ++----------------------------------+------------------+ +| inception_next_tiny | | ++----------------------------------+------------------+ +| mvitv2_base | | ++----------------------------------+------------------+ +| mvitv2_base_cls | | ++----------------------------------+------------------+ +| mvitv2_huge_cls | | ++----------------------------------+------------------+ +| mvitv2_large | | ++----------------------------------+------------------+ +| mvitv2_large_cls | | ++----------------------------------+------------------+ +| mvitv2_small | | ++----------------------------------+------------------+ +| mvitv2_small_cls | | ++----------------------------------+------------------+ +| mvitv2_tiny | | ++----------------------------------+------------------+ +| nextvit_base | | ++----------------------------------+------------------+ +| nextvit_large | | ++----------------------------------+------------------+ +| nextvit_small | | ++----------------------------------+------------------+ +| poolformer_m36 | | ++----------------------------------+------------------+ +| poolformer_m48 | | ++----------------------------------+------------------+ +| poolformer_s12 | | ++----------------------------------+------------------+ +| poolformer_s24 | | ++----------------------------------+------------------+ +| poolformer_s36 | | ++----------------------------------+------------------+ +| poolformerv2_m36 | | ++----------------------------------+------------------+ +| poolformerv2_m48 | | ++----------------------------------+------------------+ +| poolformerv2_s12 | | ++----------------------------------+------------------+ +| poolformerv2_s24 | | ++----------------------------------+------------------+ +| poolformerv2_s36 | | ++----------------------------------+------------------+ +| pvt_v2_b0 | | ++----------------------------------+------------------+ +| pvt_v2_b1 | | ++----------------------------------+------------------+ +| pvt_v2_b2 | | ++----------------------------------+------------------+ +| pvt_v2_b2_li | | ++----------------------------------+------------------+ +| pvt_v2_b3 | | ++----------------------------------+------------------+ +| pvt_v2_b4 | | ++----------------------------------+------------------+ +| pvt_v2_b5 | | ++----------------------------------+------------------+ +| rdnet_base | | ++----------------------------------+------------------+ +| rdnet_large | | ++----------------------------------+------------------+ +| rdnet_small | | ++----------------------------------+------------------+ +| rdnet_tiny | | ++----------------------------------+------------------+ +| repvit_m0_9 | | ++----------------------------------+------------------+ +| repvit_m1 | | ++----------------------------------+------------------+ +| repvit_m1_0 | | ++----------------------------------+------------------+ +| repvit_m1_1 | | ++----------------------------------+------------------+ +| repvit_m1_5 | | ++----------------------------------+------------------+ +| repvit_m2 | | ++----------------------------------+------------------+ +| repvit_m2_3 | | ++----------------------------------+------------------+ +| repvit_m3 | | ++----------------------------------+------------------+ +| sam2_hiera_base_plus | | ++----------------------------------+------------------+ +| sam2_hiera_large | | ++----------------------------------+------------------+ +| sam2_hiera_small | | ++----------------------------------+------------------+ +| sam2_hiera_tiny | | 
++----------------------------------+------------------+ +| swinv2_cr_base_224 | | ++----------------------------------+------------------+ +| swinv2_cr_base_384 | | ++----------------------------------+------------------+ +| swinv2_cr_base_ns_224 | | ++----------------------------------+------------------+ +| swinv2_cr_giant_224 | | ++----------------------------------+------------------+ +| swinv2_cr_giant_384 | | ++----------------------------------+------------------+ +| swinv2_cr_huge_224 | | ++----------------------------------+------------------+ +| swinv2_cr_huge_384 | | ++----------------------------------+------------------+ +| swinv2_cr_large_224 | | ++----------------------------------+------------------+ +| swinv2_cr_large_384 | | ++----------------------------------+------------------+ +| swinv2_cr_small_224 | | ++----------------------------------+------------------+ +| swinv2_cr_small_384 | | ++----------------------------------+------------------+ +| swinv2_cr_small_ns_224 | | ++----------------------------------+------------------+ +| swinv2_cr_small_ns_256 | | ++----------------------------------+------------------+ +| swinv2_cr_tiny_224 | | ++----------------------------------+------------------+ +| swinv2_cr_tiny_384 | | ++----------------------------------+------------------+ +| swinv2_cr_tiny_ns_224 | | ++----------------------------------+------------------+ +| tiny_vit_5m_224 | | ++----------------------------------+------------------+ +| tiny_vit_11m_224 | | ++----------------------------------+------------------+ +| tiny_vit_21m_224 | | ++----------------------------------+------------------+ +| tiny_vit_21m_384 | | ++----------------------------------+------------------+ +| tiny_vit_21m_512 | | ++----------------------------------+------------------+ +| tresnet_l | | ++----------------------------------+------------------+ +| tresnet_m | | ++----------------------------------+------------------+ +| tresnet_v2_l | | ++----------------------------------+------------------+ +| tresnet_xl | | ++----------------------------------+------------------+ +| twins_pcpvt_base | | ++----------------------------------+------------------+ +| twins_pcpvt_large | | ++----------------------------------+------------------+ +| twins_pcpvt_small | | ++----------------------------------+------------------+ +| twins_svt_base | | ++----------------------------------+------------------+ +| twins_svt_large | | ++----------------------------------+------------------+ +| twins_svt_small | | ++----------------------------------+------------------+ + From 029c1901834f9e19ac2c12294f47b123bae570cc Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Sun, 8 Dec 2024 06:10:37 +0800 Subject: [PATCH 04/14] Fix typo error --- docs/encoders_timm.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/encoders_timm.rst b/docs/encoders_timm.rst index 903f99cd..fea4e6ce 100644 --- a/docs/encoders_timm.rst +++ b/docs/encoders_timm.rst @@ -16,7 +16,7 @@ Total number of encoders: 761 (579+182) To use following encoders you have to add prefix ``tu-``, e.g. 
``tu-adv_inception_v3`` -Tranditional-Style Models +Traditional-Style Models ~~~~~~~~~~~~~~~~~~~~~~~~~ These models typically produce feature maps at the following downsampling scales relative to the input resolution: 1/2, 1/4, 1/8, 1/16, and 1/32 From c58b7cf2ef5b0392d51d70ffdce3c8afcbf00bc7 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Sun, 8 Dec 2024 21:59:16 +0800 Subject: [PATCH 05/14] Fix typo error & Update doc --- docs/encoders_timm.rst | 4 ++-- segmentation_models_pytorch/encoders/timm_universal.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/encoders_timm.rst b/docs/encoders_timm.rst index fea4e6ce..37a7ab11 100644 --- a/docs/encoders_timm.rst +++ b/docs/encoders_timm.rst @@ -16,8 +16,8 @@ Total number of encoders: 761 (579+182) To use following encoders you have to add prefix ``tu-``, e.g. ``tu-adv_inception_v3`` -Traditional-Style Models -~~~~~~~~~~~~~~~~~~~~~~~~~ +Traditional-Style +~~~~~~~~~~~~~~~~~ These models typically produce feature maps at the following downsampling scales relative to the input resolution: 1/2, 1/4, 1/8, 1/16, and 1/32 diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py index e3b4ac17..2e9ed454 100644 --- a/segmentation_models_pytorch/encoders/timm_universal.py +++ b/segmentation_models_pytorch/encoders/timm_universal.py @@ -117,7 +117,7 @@ def __init__( else: if "dla" in name: # For 'dla' models, out_indices starts at 0 and matches the input size. - kwargs["out_indices"] = tuple(range(1, depth + 1)) + common_kwargs["out_indices"] = tuple(range(1, depth + 1)) self.model = timm.create_model( name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs) From eae7e2b10baec3c04725a7615fe5ef8b24a51420 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Sun, 8 Dec 2024 22:52:03 +0800 Subject: [PATCH 06/14] Fix typo error --- docs/encoders_timm.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/encoders_timm.rst b/docs/encoders_timm.rst index 37a7ab11..2bd39f96 100644 --- a/docs/encoders_timm.rst +++ b/docs/encoders_timm.rst @@ -1183,7 +1183,7 @@ These models typically produce feature maps at the following downsampling scales | xception71 | ✅ | +----------------------------------+------------------+ -Transformer-style +Transformer-Style ~~~~~~~~~~~~~~~~~ Transformer-style models (e.g., Swin Transformer, ConvNeXt) typically produce feature maps starting at a 1/4 scale, followed by 1/8, 1/16, and 1/32 scales From 4148788f0c1e9aa075d9c70c1699b1df4baa9987 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Wed, 11 Dec 2024 00:38:33 +0800 Subject: [PATCH 07/14] Support channel-last format --- docs/encoders_timm.rst | 799 ++++++++++-------- .../encoders/timm_universal.py | 13 +- 2 files changed, 440 insertions(+), 372 deletions(-) diff --git a/docs/encoders_timm.rst b/docs/encoders_timm.rst index 2bd39f96..4866d6e1 100644 --- a/docs/encoders_timm.rst +++ b/docs/encoders_timm.rst @@ -6,11 +6,10 @@ however, not all models are supported - not all transformer models have ``features_only`` functionality implemented that is required for encoder - some models have inappropriate strides - - some models (such as certain Swin/SwinV2 variants or those outputting in (B, H, W, C) format) are currently not supported Below is a table of suitable encoders (for DeepLabV3, DeepLabV3+, and PAN dilation support is needed also) -Total number of 
encoders: 761 (579+182) +Total number of encoders: 792 (579+213) .. note:: @@ -1188,371 +1187,433 @@ Transformer-Style Transformer-style models (e.g., Swin Transformer, ConvNeXt) typically produce feature maps starting at a 1/4 scale, followed by 1/8, 1/16, and 1/32 scales -+----------------------------------+------------------+ -| Encoder name | Support dilation | -+==================================+==================+ -| caformer_b36 | | -+----------------------------------+------------------+ -| caformer_m36 | | -+----------------------------------+------------------+ -| caformer_s18 | | -+----------------------------------+------------------+ -| caformer_s36 | | -+----------------------------------+------------------+ -| convformer_b36 | | -+----------------------------------+------------------+ -| convformer_m36 | | -+----------------------------------+------------------+ -| convformer_s18 | | -+----------------------------------+------------------+ -| convformer_s36 | | -+----------------------------------+------------------+ -| convnext_atto | ✅ | -+----------------------------------+------------------+ -| convnext_atto_ols | ✅ | -+----------------------------------+------------------+ -| convnext_atto_rms | ✅ | -+----------------------------------+------------------+ -| convnext_base | ✅ | -+----------------------------------+------------------+ -| convnext_femto | ✅ | -+----------------------------------+------------------+ -| convnext_femto_ols | ✅ | -+----------------------------------+------------------+ -| convnext_large | ✅ | -+----------------------------------+------------------+ -| convnext_large_mlp | ✅ | -+----------------------------------+------------------+ -| convnext_nano | ✅ | -+----------------------------------+------------------+ -| convnext_nano_ols | ✅ | -+----------------------------------+------------------+ -| convnext_pico | ✅ | -+----------------------------------+------------------+ -| convnext_pico_ols | ✅ | -+----------------------------------+------------------+ -| convnext_small | ✅ | -+----------------------------------+------------------+ -| convnext_tiny | ✅ | -+----------------------------------+------------------+ -| convnext_tiny_hnf | ✅ | -+----------------------------------+------------------+ -| convnext_xlarge | ✅ | -+----------------------------------+------------------+ -| convnext_xxlarge | ✅ | -+----------------------------------+------------------+ -| convnext_zepto_rms | ✅ | -+----------------------------------+------------------+ -| convnext_zepto_rms_ols | ✅ | -+----------------------------------+------------------+ -| convnextv2_atto | ✅ | -+----------------------------------+------------------+ -| convnextv2_base | ✅ | -+----------------------------------+------------------+ -| convnextv2_femto | ✅ | -+----------------------------------+------------------+ -| convnextv2_huge | ✅ | -+----------------------------------+------------------+ -| convnextv2_large | ✅ | -+----------------------------------+------------------+ -| convnextv2_nano | ✅ | -+----------------------------------+------------------+ -| convnextv2_pico | ✅ | -+----------------------------------+------------------+ -| convnextv2_small | ✅ | -+----------------------------------+------------------+ -| convnextv2_tiny | ✅ | -+----------------------------------+------------------+ -| davit_base | | -+----------------------------------+------------------+ -| davit_base_fl | | -+----------------------------------+------------------+ -| davit_giant | | 
-+----------------------------------+------------------+ -| davit_huge | | -+----------------------------------+------------------+ -| davit_huge_fl | | -+----------------------------------+------------------+ -| davit_large | | -+----------------------------------+------------------+ -| davit_small | | -+----------------------------------+------------------+ -| davit_tiny | | -+----------------------------------+------------------+ -| edgenext_base | | -+----------------------------------+------------------+ -| edgenext_small | | -+----------------------------------+------------------+ -| edgenext_small_rw | | -+----------------------------------+------------------+ -| edgenext_x_small | | -+----------------------------------+------------------+ -| edgenext_xx_small | | -+----------------------------------+------------------+ -| efficientformer_l1 | | -+----------------------------------+------------------+ -| efficientformer_l3 | | -+----------------------------------+------------------+ -| efficientformer_l7 | | -+----------------------------------+------------------+ -| efficientformerv2_l | | -+----------------------------------+------------------+ -| efficientformerv2_s0 | | -+----------------------------------+------------------+ -| efficientformerv2_s1 | | -+----------------------------------+------------------+ -| efficientformerv2_s2 | | -+----------------------------------+------------------+ -| efficientvit_b0 | | -+----------------------------------+------------------+ -| efficientvit_b1 | | -+----------------------------------+------------------+ -| efficientvit_b2 | | -+----------------------------------+------------------+ -| efficientvit_b3 | | -+----------------------------------+------------------+ -| efficientvit_l1 | | -+----------------------------------+------------------+ -| efficientvit_l2 | | -+----------------------------------+------------------+ -| efficientvit_l3 | | -+----------------------------------+------------------+ -| fastvit_ma36 | | -+----------------------------------+------------------+ -| fastvit_mci0 | | -+----------------------------------+------------------+ -| fastvit_mci1 | | -+----------------------------------+------------------+ -| fastvit_mci2 | | -+----------------------------------+------------------+ -| fastvit_s12 | | -+----------------------------------+------------------+ -| fastvit_sa12 | | -+----------------------------------+------------------+ -| fastvit_sa24 | | -+----------------------------------+------------------+ -| fastvit_sa36 | | -+----------------------------------+------------------+ -| fastvit_t8 | | -+----------------------------------+------------------+ -| fastvit_t12 | | -+----------------------------------+------------------+ -| focalnet_base_lrf | | -+----------------------------------+------------------+ -| focalnet_base_srf | | -+----------------------------------+------------------+ -| focalnet_huge_fl3 | | -+----------------------------------+------------------+ -| focalnet_huge_fl4 | | -+----------------------------------+------------------+ -| focalnet_large_fl3 | | -+----------------------------------+------------------+ -| focalnet_large_fl4 | | -+----------------------------------+------------------+ -| focalnet_small_lrf | | -+----------------------------------+------------------+ -| focalnet_small_srf | | -+----------------------------------+------------------+ -| focalnet_tiny_lrf | | -+----------------------------------+------------------+ -| focalnet_tiny_srf | | 
-+----------------------------------+------------------+ -| focalnet_xlarge_fl3 | | -+----------------------------------+------------------+ -| focalnet_xlarge_fl4 | | -+----------------------------------+------------------+ -| hgnet_base | | -+----------------------------------+------------------+ -| hgnet_small | | -+----------------------------------+------------------+ -| hgnet_tiny | | -+----------------------------------+------------------+ -| hgnetv2_b0 | | -+----------------------------------+------------------+ -| hgnetv2_b1 | | -+----------------------------------+------------------+ -| hgnetv2_b2 | | -+----------------------------------+------------------+ -| hgnetv2_b3 | | -+----------------------------------+------------------+ -| hgnetv2_b4 | | -+----------------------------------+------------------+ -| hgnetv2_b5 | | -+----------------------------------+------------------+ -| hgnetv2_b6 | | -+----------------------------------+------------------+ -| hiera_base_224 | | -+----------------------------------+------------------+ -| hiera_base_abswin_256 | | -+----------------------------------+------------------+ -| hiera_base_plus_224 | | -+----------------------------------+------------------+ -| hiera_huge_224 | | -+----------------------------------+------------------+ -| hiera_large_224 | | -+----------------------------------+------------------+ -| hiera_small_224 | | -+----------------------------------+------------------+ -| hiera_small_abswin_256 | | -+----------------------------------+------------------+ -| hiera_tiny_224 | | -+----------------------------------+------------------+ -| hieradet_small | | -+----------------------------------+------------------+ -| inception_next_base | | -+----------------------------------+------------------+ -| inception_next_small | | -+----------------------------------+------------------+ -| inception_next_tiny | | -+----------------------------------+------------------+ -| mvitv2_base | | -+----------------------------------+------------------+ -| mvitv2_base_cls | | -+----------------------------------+------------------+ -| mvitv2_huge_cls | | -+----------------------------------+------------------+ -| mvitv2_large | | -+----------------------------------+------------------+ -| mvitv2_large_cls | | -+----------------------------------+------------------+ -| mvitv2_small | | -+----------------------------------+------------------+ -| mvitv2_small_cls | | -+----------------------------------+------------------+ -| mvitv2_tiny | | -+----------------------------------+------------------+ -| nextvit_base | | -+----------------------------------+------------------+ -| nextvit_large | | -+----------------------------------+------------------+ -| nextvit_small | | -+----------------------------------+------------------+ -| poolformer_m36 | | -+----------------------------------+------------------+ -| poolformer_m48 | | -+----------------------------------+------------------+ -| poolformer_s12 | | -+----------------------------------+------------------+ -| poolformer_s24 | | -+----------------------------------+------------------+ -| poolformer_s36 | | -+----------------------------------+------------------+ -| poolformerv2_m36 | | -+----------------------------------+------------------+ -| poolformerv2_m48 | | -+----------------------------------+------------------+ -| poolformerv2_s12 | | -+----------------------------------+------------------+ -| poolformerv2_s24 | | -+----------------------------------+------------------+ -| 
poolformerv2_s36 | | -+----------------------------------+------------------+ -| pvt_v2_b0 | | -+----------------------------------+------------------+ -| pvt_v2_b1 | | -+----------------------------------+------------------+ -| pvt_v2_b2 | | -+----------------------------------+------------------+ -| pvt_v2_b2_li | | -+----------------------------------+------------------+ -| pvt_v2_b3 | | -+----------------------------------+------------------+ -| pvt_v2_b4 | | -+----------------------------------+------------------+ -| pvt_v2_b5 | | -+----------------------------------+------------------+ -| rdnet_base | | -+----------------------------------+------------------+ -| rdnet_large | | -+----------------------------------+------------------+ -| rdnet_small | | -+----------------------------------+------------------+ -| rdnet_tiny | | -+----------------------------------+------------------+ -| repvit_m0_9 | | -+----------------------------------+------------------+ -| repvit_m1 | | -+----------------------------------+------------------+ -| repvit_m1_0 | | -+----------------------------------+------------------+ -| repvit_m1_1 | | -+----------------------------------+------------------+ -| repvit_m1_5 | | -+----------------------------------+------------------+ -| repvit_m2 | | -+----------------------------------+------------------+ -| repvit_m2_3 | | -+----------------------------------+------------------+ -| repvit_m3 | | -+----------------------------------+------------------+ -| sam2_hiera_base_plus | | -+----------------------------------+------------------+ -| sam2_hiera_large | | -+----------------------------------+------------------+ -| sam2_hiera_small | | -+----------------------------------+------------------+ -| sam2_hiera_tiny | | -+----------------------------------+------------------+ -| swinv2_cr_base_224 | | -+----------------------------------+------------------+ -| swinv2_cr_base_384 | | -+----------------------------------+------------------+ -| swinv2_cr_base_ns_224 | | -+----------------------------------+------------------+ -| swinv2_cr_giant_224 | | -+----------------------------------+------------------+ -| swinv2_cr_giant_384 | | -+----------------------------------+------------------+ -| swinv2_cr_huge_224 | | -+----------------------------------+------------------+ -| swinv2_cr_huge_384 | | -+----------------------------------+------------------+ -| swinv2_cr_large_224 | | -+----------------------------------+------------------+ -| swinv2_cr_large_384 | | -+----------------------------------+------------------+ -| swinv2_cr_small_224 | | -+----------------------------------+------------------+ -| swinv2_cr_small_384 | | -+----------------------------------+------------------+ -| swinv2_cr_small_ns_224 | | -+----------------------------------+------------------+ -| swinv2_cr_small_ns_256 | | -+----------------------------------+------------------+ -| swinv2_cr_tiny_224 | | -+----------------------------------+------------------+ -| swinv2_cr_tiny_384 | | -+----------------------------------+------------------+ -| swinv2_cr_tiny_ns_224 | | -+----------------------------------+------------------+ -| tiny_vit_5m_224 | | -+----------------------------------+------------------+ -| tiny_vit_11m_224 | | -+----------------------------------+------------------+ -| tiny_vit_21m_224 | | -+----------------------------------+------------------+ -| tiny_vit_21m_384 | | -+----------------------------------+------------------+ -| tiny_vit_21m_512 | | 
-+----------------------------------+------------------+ -| tresnet_l | | -+----------------------------------+------------------+ -| tresnet_m | | -+----------------------------------+------------------+ -| tresnet_v2_l | | -+----------------------------------+------------------+ -| tresnet_xl | | -+----------------------------------+------------------+ -| twins_pcpvt_base | | -+----------------------------------+------------------+ -| twins_pcpvt_large | | -+----------------------------------+------------------+ -| twins_pcpvt_small | | -+----------------------------------+------------------+ -| twins_svt_base | | -+----------------------------------+------------------+ -| twins_svt_large | | -+----------------------------------+------------------+ -| twins_svt_small | | -+----------------------------------+------------------+ ++------------------------------------+------------------+ +| Encoder name | Support dilation | ++====================================+==================+ +| caformer_b36 | | ++------------------------------------+------------------+ +| caformer_m36 | | ++------------------------------------+------------------+ +| caformer_s18 | | ++------------------------------------+------------------+ +| caformer_s36 | | ++------------------------------------+------------------+ +| convformer_b36 | | ++------------------------------------+------------------+ +| convformer_m36 | | ++------------------------------------+------------------+ +| convformer_s18 | | ++------------------------------------+------------------+ +| convformer_s36 | | ++------------------------------------+------------------+ +| convnext_atto | ✅ | ++------------------------------------+------------------+ +| convnext_atto_ols | ✅ | ++------------------------------------+------------------+ +| convnext_atto_rms | ✅ | ++------------------------------------+------------------+ +| convnext_base | ✅ | ++------------------------------------+------------------+ +| convnext_femto | ✅ | ++------------------------------------+------------------+ +| convnext_femto_ols | ✅ | ++------------------------------------+------------------+ +| convnext_large | ✅ | ++------------------------------------+------------------+ +| convnext_large_mlp | ✅ | ++------------------------------------+------------------+ +| convnext_nano | ✅ | ++------------------------------------+------------------+ +| convnext_nano_ols | ✅ | ++------------------------------------+------------------+ +| convnext_pico | ✅ | ++------------------------------------+------------------+ +| convnext_pico_ols | ✅ | ++------------------------------------+------------------+ +| convnext_small | ✅ | ++------------------------------------+------------------+ +| convnext_tiny | ✅ | ++------------------------------------+------------------+ +| convnext_tiny_hnf | ✅ | ++------------------------------------+------------------+ +| convnext_xlarge | ✅ | ++------------------------------------+------------------+ +| convnext_xxlarge | ✅ | ++------------------------------------+------------------+ +| convnext_zepto_rms | ✅ | ++------------------------------------+------------------+ +| convnext_zepto_rms_ols | ✅ | ++------------------------------------+------------------+ +| convnextv2_atto | ✅ | ++------------------------------------+------------------+ +| convnextv2_base | ✅ | ++------------------------------------+------------------+ +| convnextv2_femto | ✅ | ++------------------------------------+------------------+ +| convnextv2_huge | ✅ | 
++------------------------------------+------------------+ +| convnextv2_large | ✅ | ++------------------------------------+------------------+ +| convnextv2_nano | ✅ | ++------------------------------------+------------------+ +| convnextv2_pico | ✅ | ++------------------------------------+------------------+ +| convnextv2_small | ✅ | ++------------------------------------+------------------+ +| convnextv2_tiny | ✅ | ++------------------------------------+------------------+ +| davit_base | | ++------------------------------------+------------------+ +| davit_base_fl | | ++------------------------------------+------------------+ +| davit_giant | | ++------------------------------------+------------------+ +| davit_huge | | ++------------------------------------+------------------+ +| davit_huge_fl | | ++------------------------------------+------------------+ +| davit_large | | ++------------------------------------+------------------+ +| davit_small | | ++------------------------------------+------------------+ +| davit_tiny | | ++------------------------------------+------------------+ +| edgenext_base | | ++------------------------------------+------------------+ +| edgenext_small | | ++------------------------------------+------------------+ +| edgenext_small_rw | | ++------------------------------------+------------------+ +| edgenext_x_small | | ++------------------------------------+------------------+ +| edgenext_xx_small | | ++------------------------------------+------------------+ +| efficientformer_l1 | | ++------------------------------------+------------------+ +| efficientformer_l3 | | ++------------------------------------+------------------+ +| efficientformer_l7 | | ++------------------------------------+------------------+ +| efficientformerv2_l | | ++------------------------------------+------------------+ +| efficientformerv2_s0 | | ++------------------------------------+------------------+ +| efficientformerv2_s1 | | ++------------------------------------+------------------+ +| efficientformerv2_s2 | | ++------------------------------------+------------------+ +| efficientvit_b0 | | ++------------------------------------+------------------+ +| efficientvit_b1 | | ++------------------------------------+------------------+ +| efficientvit_b2 | | ++------------------------------------+------------------+ +| efficientvit_b3 | | ++------------------------------------+------------------+ +| efficientvit_l1 | | ++------------------------------------+------------------+ +| efficientvit_l2 | | ++------------------------------------+------------------+ +| efficientvit_l3 | | ++------------------------------------+------------------+ +| fastvit_ma36 | | ++------------------------------------+------------------+ +| fastvit_mci0 | | ++------------------------------------+------------------+ +| fastvit_mci1 | | ++------------------------------------+------------------+ +| fastvit_mci2 | | ++------------------------------------+------------------+ +| fastvit_s12 | | ++------------------------------------+------------------+ +| fastvit_sa12 | | ++------------------------------------+------------------+ +| fastvit_sa24 | | ++------------------------------------+------------------+ +| fastvit_sa36 | | ++------------------------------------+------------------+ +| fastvit_t8 | | ++------------------------------------+------------------+ +| fastvit_t12 | | ++------------------------------------+------------------+ +| focalnet_base_lrf | | 
++------------------------------------+------------------+ +| focalnet_base_srf | | ++------------------------------------+------------------+ +| focalnet_huge_fl3 | | ++------------------------------------+------------------+ +| focalnet_huge_fl4 | | ++------------------------------------+------------------+ +| focalnet_large_fl3 | | ++------------------------------------+------------------+ +| focalnet_large_fl4 | | ++------------------------------------+------------------+ +| focalnet_small_lrf | | ++------------------------------------+------------------+ +| focalnet_small_srf | | ++------------------------------------+------------------+ +| focalnet_tiny_lrf | | ++------------------------------------+------------------+ +| focalnet_tiny_srf | | ++------------------------------------+------------------+ +| focalnet_xlarge_fl3 | | ++------------------------------------+------------------+ +| focalnet_xlarge_fl4 | | ++------------------------------------+------------------+ +| hgnet_base | | ++------------------------------------+------------------+ +| hgnet_small | | ++------------------------------------+------------------+ +| hgnet_tiny | | ++------------------------------------+------------------+ +| hgnetv2_b0 | | ++------------------------------------+------------------+ +| hgnetv2_b1 | | ++------------------------------------+------------------+ +| hgnetv2_b2 | | ++------------------------------------+------------------+ +| hgnetv2_b3 | | ++------------------------------------+------------------+ +| hgnetv2_b4 | | ++------------------------------------+------------------+ +| hgnetv2_b5 | | ++------------------------------------+------------------+ +| hgnetv2_b6 | | ++------------------------------------+------------------+ +| hiera_base_224 | | ++------------------------------------+------------------+ +| hiera_base_abswin_256 | | ++------------------------------------+------------------+ +| hiera_base_plus_224 | | ++------------------------------------+------------------+ +| hiera_huge_224 | | ++------------------------------------+------------------+ +| hiera_large_224 | | ++------------------------------------+------------------+ +| hiera_small_224 | | ++------------------------------------+------------------+ +| hiera_small_abswin_256 | | ++------------------------------------+------------------+ +| hiera_tiny_224 | | ++------------------------------------+------------------+ +| hieradet_small | | ++------------------------------------+------------------+ +| inception_next_base | | ++------------------------------------+------------------+ +| inception_next_small | | ++------------------------------------+------------------+ +| inception_next_tiny | | ++------------------------------------+------------------+ +| mambaout_base | | ++------------------------------------+------------------+ +| mambaout_base_plus_rw | | ++------------------------------------+------------------+ +| mambaout_base_short_rw | | ++------------------------------------+------------------+ +| mambaout_base_tall_rw | | ++------------------------------------+------------------+ +| mambaout_base_wide_rw | | ++------------------------------------+------------------+ +| mambaout_femto | | ++------------------------------------+------------------+ +| mambaout_kobe | | ++------------------------------------+------------------+ +| mambaout_small | | ++------------------------------------+------------------+ +| mambaout_small_rw | | ++------------------------------------+------------------+ +| mambaout_tiny | | 
++------------------------------------+------------------+ +| mvitv2_base | | ++------------------------------------+------------------+ +| mvitv2_base_cls | | ++------------------------------------+------------------+ +| mvitv2_huge_cls | | ++------------------------------------+------------------+ +| mvitv2_large | | ++------------------------------------+------------------+ +| mvitv2_large_cls | | ++------------------------------------+------------------+ +| mvitv2_small | | ++------------------------------------+------------------+ +| mvitv2_small_cls | | ++------------------------------------+------------------+ +| mvitv2_tiny | | ++------------------------------------+------------------+ +| nextvit_base | | ++------------------------------------+------------------+ +| nextvit_large | | ++------------------------------------+------------------+ +| nextvit_small | | ++------------------------------------+------------------+ +| poolformer_m36 | | ++------------------------------------+------------------+ +| poolformer_m48 | | ++------------------------------------+------------------+ +| poolformer_s12 | | ++------------------------------------+------------------+ +| poolformer_s24 | | ++------------------------------------+------------------+ +| poolformer_s36 | | ++------------------------------------+------------------+ +| poolformerv2_m36 | | ++------------------------------------+------------------+ +| poolformerv2_m48 | | ++------------------------------------+------------------+ +| poolformerv2_s12 | | ++------------------------------------+------------------+ +| poolformerv2_s24 | | ++------------------------------------+------------------+ +| poolformerv2_s36 | | ++------------------------------------+------------------+ +| pvt_v2_b0 | | ++------------------------------------+------------------+ +| pvt_v2_b1 | | ++------------------------------------+------------------+ +| pvt_v2_b2 | | ++------------------------------------+------------------+ +| pvt_v2_b2_li | | ++------------------------------------+------------------+ +| pvt_v2_b3 | | ++------------------------------------+------------------+ +| pvt_v2_b4 | | ++------------------------------------+------------------+ +| pvt_v2_b5 | | ++------------------------------------+------------------+ +| rdnet_base | | ++------------------------------------+------------------+ +| rdnet_large | | ++------------------------------------+------------------+ +| rdnet_small | | ++------------------------------------+------------------+ +| rdnet_tiny | | ++------------------------------------+------------------+ +| repvit_m0_9 | | ++------------------------------------+------------------+ +| repvit_m1 | | ++------------------------------------+------------------+ +| repvit_m1_0 | | ++------------------------------------+------------------+ +| repvit_m1_1 | | ++------------------------------------+------------------+ +| repvit_m1_5 | | ++------------------------------------+------------------+ +| repvit_m2 | | ++------------------------------------+------------------+ +| repvit_m2_3 | | ++------------------------------------+------------------+ +| repvit_m3 | | ++------------------------------------+------------------+ +| sam2_hiera_base_plus | | ++------------------------------------+------------------+ +| sam2_hiera_large | | ++------------------------------------+------------------+ +| sam2_hiera_small | | ++------------------------------------+------------------+ +| sam2_hiera_tiny | | 
++------------------------------------+------------------+ +| swin_base_patch4_window7_224 | | ++------------------------------------+------------------+ +| swin_base_patch4_window12_384 | | ++------------------------------------+------------------+ +| swin_large_patch4_window7_224 | | ++------------------------------------+------------------+ +| swin_large_patch4_window12_384 | | ++------------------------------------+------------------+ +| swin_s3_base_224 | | ++------------------------------------+------------------+ +| swin_s3_small_224 | | ++------------------------------------+------------------+ +| swin_s3_tiny_224 | | ++------------------------------------+------------------+ +| swin_small_patch4_window7_224 | | ++------------------------------------+------------------+ +| swin_tiny_patch4_window7_224 | | ++------------------------------------+------------------+ +| swinv2_base_window8_256 | | ++------------------------------------+------------------+ +| swinv2_base_window12_192 | | ++------------------------------------+------------------+ +| swinv2_base_window12to16_192to256 | | ++------------------------------------+------------------+ +| swinv2_base_window12to24_192to384 | | ++------------------------------------+------------------+ +| swinv2_base_window16_256 | | ++------------------------------------+------------------+ +| swinv2_cr_base_224 | | ++------------------------------------+------------------+ +| swinv2_cr_base_384 | | ++------------------------------------+------------------+ +| swinv2_cr_base_ns_224 | | ++------------------------------------+------------------+ +| swinv2_cr_giant_224 | | ++------------------------------------+------------------+ +| swinv2_cr_giant_384 | | ++------------------------------------+------------------+ +| swinv2_cr_huge_224 | | ++------------------------------------+------------------+ +| swinv2_cr_huge_384 | | ++------------------------------------+------------------+ +| swinv2_cr_large_224 | | ++------------------------------------+------------------+ +| swinv2_cr_large_384 | | ++------------------------------------+------------------+ +| swinv2_cr_small_224 | | ++------------------------------------+------------------+ +| swinv2_cr_small_384 | | ++------------------------------------+------------------+ +| swinv2_cr_small_ns_224 | | ++------------------------------------+------------------+ +| swinv2_cr_small_ns_256 | | ++------------------------------------+------------------+ +| swinv2_cr_tiny_224 | | ++------------------------------------+------------------+ +| swinv2_cr_tiny_384 | | ++------------------------------------+------------------+ +| swinv2_cr_tiny_ns_224 | | ++------------------------------------+------------------+ +| swinv2_large_window12_192 | | ++------------------------------------+------------------+ +| swinv2_large_window12to16_192to256 | | ++------------------------------------+------------------+ +| swinv2_large_window12to24_192to384 | | ++------------------------------------+------------------+ +| swinv2_small_window8_256 | | ++------------------------------------+------------------+ +| swinv2_small_window16_256 | | ++------------------------------------+------------------+ +| swinv2_tiny_window8_256 | | ++------------------------------------+------------------+ +| swinv2_tiny_window16_256 | | ++------------------------------------+------------------+ +| tiny_vit_5m_224 | | ++------------------------------------+------------------+ +| tiny_vit_11m_224 | | 
++------------------------------------+------------------+ +| tiny_vit_21m_224 | | ++------------------------------------+------------------+ +| tiny_vit_21m_384 | | ++------------------------------------+------------------+ +| tiny_vit_21m_512 | | ++------------------------------------+------------------+ +| tresnet_l | | ++------------------------------------+------------------+ +| tresnet_m | | ++------------------------------------+------------------+ +| tresnet_v2_l | | ++------------------------------------+------------------+ +| tresnet_xl | | ++------------------------------------+------------------+ +| twins_pcpvt_base | | ++------------------------------------+------------------+ +| twins_pcpvt_large | | ++------------------------------------+------------------+ +| twins_pcpvt_small | | ++------------------------------------+------------------+ +| twins_svt_base | | ++------------------------------------+------------------+ +| twins_svt_large | | ++------------------------------------+------------------+ +| twins_svt_small | | ++------------------------------------+------------------+ diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py index 2e9ed454..06bd267b 100644 --- a/segmentation_models_pytorch/encoders/timm_universal.py +++ b/segmentation_models_pytorch/encoders/timm_universal.py @@ -31,9 +31,6 @@ transformer-like models). - Certain models (e.g., TResNet, DLA) require special handling to ensure correct feature indexing. -- Most `timm` models output features in (B, C, H, W) format. However, some - (e.g., MambaOut and certain Swin/SwinV2 variants) use (B, H, W, C) format, which is - currently unsupported. """ from typing import Any @@ -85,12 +82,16 @@ def __init__( out_indices=tuple(range(depth)), ) + # not all models support output stride argument, drop it by default if output_stride == 32: common_kwargs.pop("output_stride") # Load a preliminary model to determine its feature hierarchy structure. self.model = timm.create_model(name, features_only=True) + # Check if the model's output is in channel-last format (B, H, W, C). + self._is_channel_last = getattr(self.model, "output_fmt", None) == "NHWC" + # Determine if this model uses a transformer-like hierarchy (i.e., starting at 1/4 scale) # rather than a traditional CNN hierarchy (starting at 1/2 scale). if len(self.model.feature_info.channels()) == 5: @@ -140,6 +141,12 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]: """ features = self.model(x) + if self._is_channel_last: + # Convert to channel-first (B, C, H, W). + features = [ + feature.permute(0, 3, 1, 2).contiguous() for feature in features + ] + if self._is_transformer_style: # Models using a transformer-like hierarchy may not generate # all expected feature maps. 
Insert a dummy feature map to ensure From 51a4d7bab379c02717e0da16a8a5ba553e349039 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Wed, 18 Dec 2024 08:45:54 +0800 Subject: [PATCH 08/14] Update encoders_timm.rst --- docs/encoders_timm.rst | 98 +++++++++++++++++++++++++++++------------- 1 file changed, 69 insertions(+), 29 deletions(-) diff --git a/docs/encoders_timm.rst b/docs/encoders_timm.rst index 4866d6e1..31c8396e 100644 --- a/docs/encoders_timm.rst +++ b/docs/encoders_timm.rst @@ -9,7 +9,7 @@ however, not all models are supported Below is a table of suitable encoders (for DeepLabV3, DeepLabV3+, and PAN dilation support is needed also) -Total number of encoders: 792 (579+213) +Total number of encoders: 812 (593+219) .. note:: @@ -99,6 +99,8 @@ These models typically produce feature maps at the following downsampling scales +----------------------------------+------------------+ | cs3sedarknet_xdw | ✅ | +----------------------------------+------------------+ +| cspdarknet53 | ✅ | ++----------------------------------+------------------+ | cspresnet50 | ✅ | +----------------------------------+------------------+ | cspresnet50d | ✅ | @@ -107,6 +109,14 @@ These models typically produce feature maps at the following downsampling scales +----------------------------------+------------------+ | cspresnext50 | ✅ | +----------------------------------+------------------+ +| darknet17 | ✅ | ++----------------------------------+------------------+ +| darknet21 | ✅ | ++----------------------------------+------------------+ +| darknet53 | ✅ | ++----------------------------------+------------------+ +| darknetaa53 | ✅ | ++----------------------------------+------------------+ | densenet121 | | +----------------------------------+------------------+ | densenet161 | | @@ -189,14 +199,6 @@ These models typically produce feature maps at the following downsampling scales +----------------------------------+------------------+ | eca_vovnet39b | | +----------------------------------+------------------+ -| ecaresnet101d | ✅ | -+----------------------------------+------------------+ -| ecaresnet101d_pruned | ✅ | -+----------------------------------+------------------+ -| ecaresnet200d | ✅ | -+----------------------------------+------------------+ -| ecaresnet269d | ✅ | -+----------------------------------+------------------+ | ecaresnet26t | ✅ | +----------------------------------+------------------+ | ecaresnet50d | ✅ | @@ -205,6 +207,14 @@ These models typically produce feature maps at the following downsampling scales +----------------------------------+------------------+ | ecaresnet50t | ✅ | +----------------------------------+------------------+ +| ecaresnet101d | ✅ | ++----------------------------------+------------------+ +| ecaresnet101d_pruned | ✅ | ++----------------------------------+------------------+ +| ecaresnet200d | ✅ | ++----------------------------------+------------------+ +| ecaresnet269d | ✅ | ++----------------------------------+------------------+ | ecaresnetlight | ✅ | +----------------------------------+------------------+ | ecaresnext26t_32x4d | ✅ | @@ -213,10 +223,10 @@ These models typically produce feature maps at the following downsampling scales +----------------------------------+------------------+ | efficientnet_b0 | ✅ | +----------------------------------+------------------+ -| efficientnet_b0_g16_evos | ✅ | -+----------------------------------+------------------+ | efficientnet_b0_g8_gn | ✅ | 
+----------------------------------+------------------+ +| efficientnet_b0_g16_evos | ✅ | ++----------------------------------+------------------+ | efficientnet_b0_gn | ✅ | +----------------------------------+------------------+ | efficientnet_b1 | ✅ | @@ -333,12 +343,12 @@ These models typically produce feature maps at the following downsampling scales +----------------------------------+------------------+ | ghostnet_130 | | +----------------------------------+------------------+ -| ghostnetv2_050 | | -+----------------------------------+------------------+ | ghostnetv2_100 | | +----------------------------------+------------------+ | ghostnetv2_130 | | +----------------------------------+------------------+ +| ghostnetv2_160 | | ++----------------------------------+------------------+ | halo2botnet50ts_256 | ✅ | +----------------------------------+------------------+ | halonet26t | ✅ | @@ -711,14 +721,14 @@ These models typically produce feature maps at the following downsampling scales +----------------------------------+------------------+ | regnety_160 | ✅ | +----------------------------------+------------------+ -| regnety_1280 | ✅ | -+----------------------------------+------------------+ -| regnety_2560 | ✅ | -+----------------------------------+------------------+ | regnety_320 | ✅ | +----------------------------------+------------------+ | regnety_640 | ✅ | +----------------------------------+------------------+ +| regnety_1280 | ✅ | ++----------------------------------+------------------+ +| regnety_2560 | ✅ | ++----------------------------------+------------------+ | regnetz_005 | ✅ | +----------------------------------+------------------+ | regnetz_040 | ✅ | @@ -733,12 +743,12 @@ These models typically produce feature maps at the following downsampling scales +----------------------------------+------------------+ | regnetz_c16_evos | ✅ | +----------------------------------+------------------+ -| regnetz_d32 | ✅ | -+----------------------------------+------------------+ | regnetz_d8 | ✅ | +----------------------------------+------------------+ | regnetz_d8_evos | ✅ | +----------------------------------+------------------+ +| regnetz_d32 | ✅ | ++----------------------------------+------------------+ | regnetz_e8 | ✅ | +----------------------------------+------------------+ | repghostnet_050 | | @@ -837,12 +847,12 @@ These models typically produce feature maps at the following downsampling scales +----------------------------------+------------------+ | resnet50 | ✅ | +----------------------------------+------------------+ -| resnet50_gn | ✅ | -+----------------------------------+------------------+ | resnet50_clip | ✅ | +----------------------------------+------------------+ | resnet50_clip_gap | ✅ | +----------------------------------+------------------+ +| resnet50_gn | ✅ | ++----------------------------------+------------------+ | resnet50_mlp | ✅ | +----------------------------------+------------------+ | resnet50c | ✅ | @@ -1001,6 +1011,8 @@ These models typically produce feature maps at the following downsampling scales +----------------------------------+------------------+ | sebotnet33ts_256 | ✅ | +----------------------------------+------------------+ +| sedarknet21 | ✅ | ++----------------------------------+------------------+ | sehalonet33ts | ✅ | +----------------------------------+------------------+ | selecsls42 | | @@ -1045,14 +1057,6 @@ These models typically produce feature maps at the following downsampling scales 
+----------------------------------+------------------+ | seresnetaa50d | ✅ | +----------------------------------+------------------+ -| seresnext101_32x4d | ✅ | -+----------------------------------+------------------+ -| seresnext101_32x8d | ✅ | -+----------------------------------+------------------+ -| seresnext101_64x4d | ✅ | -+----------------------------------+------------------+ -| seresnext101d_32x8d | ✅ | -+----------------------------------+------------------+ | seresnext26d_32x4d | ✅ | +----------------------------------+------------------+ | seresnext26t_32x4d | ✅ | @@ -1061,6 +1065,14 @@ These models typically produce feature maps at the following downsampling scales +----------------------------------+------------------+ | seresnext50_32x4d | ✅ | +----------------------------------+------------------+ +| seresnext101_32x4d | ✅ | ++----------------------------------+------------------+ +| seresnext101_32x8d | ✅ | ++----------------------------------+------------------+ +| seresnext101_64x4d | ✅ | ++----------------------------------+------------------+ +| seresnext101d_32x8d | ✅ | ++----------------------------------+------------------+ | seresnextaa101d_32x8d | ✅ | +----------------------------------+------------------+ | seresnextaa201d_32x8d | ✅ | @@ -1163,6 +1175,22 @@ These models typically produce feature maps at the following downsampling scales +----------------------------------+------------------+ | tinynet_e | ✅ | +----------------------------------+------------------+ +| vgg11 | | ++----------------------------------+------------------+ +| vgg11_bn | | ++----------------------------------+------------------+ +| vgg13 | | ++----------------------------------+------------------+ +| vgg13_bn | | ++----------------------------------+------------------+ +| vgg16 | | ++----------------------------------+------------------+ +| vgg16_bn | | ++----------------------------------+------------------+ +| vgg19 | | ++----------------------------------+------------------+ +| vgg19_bn | | ++----------------------------------+------------------+ | vovnet39a | | +----------------------------------+------------------+ | vovnet57a | | @@ -1440,6 +1468,18 @@ Transformer-style models (e.g., Swin Transformer, ConvNeXt) typically produce fe +------------------------------------+------------------+ | mvitv2_tiny | | +------------------------------------+------------------+ +| nest_base | | ++------------------------------------+------------------+ +| nest_base_jx | | ++------------------------------------+------------------+ +| nest_small | | ++------------------------------------+------------------+ +| nest_small_jx | | ++------------------------------------+------------------+ +| nest_tiny | | ++------------------------------------+------------------+ +| nest_tiny_jx | | ++------------------------------------+------------------+ | nextvit_base | | +------------------------------------+------------------+ | nextvit_large | | From 8b0fece33cfabf89b75c12c6d929b53b3caa692e Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Wed, 18 Dec 2024 08:49:55 +0800 Subject: [PATCH 09/14] Update timm_universal.py --- .../encoders/timm_universal.py | 133 +++++++++++------- 1 file changed, 79 insertions(+), 54 deletions(-) diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py index 06bd267b..92d6dc28 100644 --- a/segmentation_models_pytorch/encoders/timm_universal.py +++ 
b/segmentation_models_pytorch/encoders/timm_universal.py @@ -1,36 +1,29 @@ """ -TimmUniversalEncoder provides a unified feature extraction interface built on the -`timm` library, supporting various backbone architectures, including traditional -CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy -(e.g., Swin Transformer, ConvNeXt). +TimmUniversalEncoder provides a unified feature extraction interface built on the +`timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style +models (e.g., Swin Transformer, ConvNeXt). -This encoder produces standardized multi-level feature maps, facilitating integration -with semantic segmentation tasks. It allows configuring the number of feature extraction -stages (`depth`) and adjusting `output_stride` when supported. +This encoder produces consistent multi-level feature maps for semantic segmentation tasks. +It allows configuring the number of feature extraction stages (`depth`) and adjusting +`output_stride` when supported. Key Features: -- Flexible model selection through `timm.create_model`. -- A unified interface that outputs consistent, multi-level features even if the - underlying model differs in its feature hierarchy. -- Automatic alignment: If a model lacks certain early-stage features (for example, - modern architectures that start from a 1/4 scale rather than 1/2 scale), the encoder - inserts dummy features to maintain consistency with traditional CNN structures. -- Easy access to channel information: Use the `out_channels` property to retrieve - the number of channels at each feature stage. +- Flexible model selection using `timm.create_model`. +- Unified multi-level output across different model hierarchies. +- Automatic alignment for inconsistent feature scales: + - Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale. + - VGG-style models (include scale-1 features): Align outputs for compatibility. +- Easy access to feature scale information via the `reduction` property. Feature Scale Differences: -- Traditional CNNs (e.g., ResNet) typically provide features at 1/2, 1/4, 1/8, 1/16, - and 1/32 scales. -- Transformer-style or next-generation models (e.g., Swin Transformer, ConvNeXt) often - start from the 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2 scale - feature. TimmUniversalEncoder compensates for this omission to ensure a unified - multi-stage output. +- Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32. +- Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale. +- VGG-style models: Include scale-1 features (input resolution). Notes: -- Not all models support modifying `output_stride` (especially transformer-based or - transformer-like models). -- Certain models (e.g., TResNet, DLA) require special handling to ensure correct - feature indexing. +- `output_stride` is unsupported in some models, especially transformer-based architectures. +- Special handling for models like TResNet and DLA to ensure correct feature indexing. +- VGG-style models use `_is_skip_first` to align scale-1 features with standard outputs. """ from typing import Any @@ -42,14 +35,13 @@ class TimmUniversalEncoder(nn.Module): """ - A universal encoder built on the `timm` library, designed to adapt to a wide variety of - model architectures, including both traditional CNNs and those that follow a - transformer-like hierarchy. 
+ A universal encoder leveraging the `timm` library for feature extraction from + various model architectures, including traditional-style and transformer-style models. Features: - - Supports flexible depth and output stride for feature extraction. - - Automatically adjusts to input/output channel structures based on the model type. - - Compatible with both convolutional and transformer-like encoders. + - Supports configurable depth and output stride. + - Ensures consistent multi-level feature extraction across diverse models. + - Compatible with convolutional and transformer-like backbones. """ def __init__( @@ -65,15 +57,16 @@ def __init__( Initialize the encoder. Args: - name (str): Name of the model to be loaded from the `timm` library. - pretrained (bool): If True, loads pretrained weights. + name (str): Model name to load from `timm`. + pretrained (bool): Load pretrained weights (default: True). in_channels (int): Number of input channels (default: 3 for RGB). - depth (int): Number of feature extraction stages (default: 5). + depth (int): Number of feature stages to extract (default: 5). output_stride (int): Desired output stride (default: 32). - **kwargs: Additional keyword arguments for `timm.create_model`. + **kwargs: Additional arguments passed to `timm.create_model`. """ super().__init__() + # Default model configuration for feature extraction common_kwargs = dict( in_chans=in_channels, features_only=True, @@ -82,24 +75,37 @@ def __init__( out_indices=tuple(range(depth)), ) - # not all models support output stride argument, drop it by default + # Not all models support output stride argument, drop it by default if output_stride == 32: common_kwargs.pop("output_stride") - # Load a preliminary model to determine its feature hierarchy structure. + # Load a temporary model to analyze its feature hierarchy self.model = timm.create_model(name, features_only=True) - # Check if the model's output is in channel-last format (B, H, W, C). + # Check if model output is in channel-last format (NHWC) self._is_channel_last = getattr(self.model, "output_fmt", None) == "NHWC" - # Determine if this model uses a transformer-like hierarchy (i.e., starting at 1/4 scale) - # rather than a traditional CNN hierarchy (starting at 1/2 scale). - if len(self.model.feature_info.channels()) == 5: + # Determine the model's downsampling pattern and set hierarchy flags + encoder_stage = len(self.model.feature_info.reduction()) + reduction_scales = self.model.feature_info.reduction() + + if reduction_scales == [2 ** (i + 2) for i in range(encoder_stage)]: + # Transformer-style downsampling: scales (4, 8, 16, 32) + self._is_transformer_style = True + self._is_skip_first = False + elif reduction_scales == [2 ** (i + 1) for i in range(encoder_stage)]: + # Traditional-style downsampling: scales (2, 4, 8, 16, 32) self._is_transformer_style = False + self._is_skip_first = False + elif reduction_scales == [2 ** i for i in range(encoder_stage)]: + # Models including scale 1: scales (1, 2, 4, 8, 16, 32) + self._is_transformer_style = False + self._is_skip_first = True else: - self._is_transformer_style = True + raise ValueError("Unsupported model downsampling pattern.") if self._is_transformer_style: + # Transformer-like models (start at scale 4) if "tresnet" in name: # 'tresnet' models start feature extraction at stage 1, # so out_indices=(1, 2, 3, 4) for depth=5. @@ -119,11 +125,17 @@ def __init__( if "dla" in name: # For 'dla' models, out_indices starts at 0 and matches the input size. 
common_kwargs["out_indices"] = tuple(range(1, depth + 1)) + if self._is_skip_first: + common_kwargs["out_indices"] = tuple(range(depth + 1)) self.model = timm.create_model( name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs) ) - self._out_channels = [in_channels] + self.model.feature_info.channels() + + if self._is_skip_first: + self._out_channels = self.model.feature_info.channels() + else: + self._out_channels = [in_channels] + self.model.feature_info.channels() self._in_channels = in_channels self._depth = depth @@ -131,30 +143,30 @@ def __init__( def forward(self, x: torch.Tensor) -> list[torch.Tensor]: """ - Pass the input through the encoder and return extracted features. + Forward pass to extract multi-stage features. Args: x (torch.Tensor): Input tensor of shape (B, C, H, W). Returns: - list[torch.Tensor]: A list of feature maps extracted at various scales. + list[torch.Tensor]: List of feature maps at different scales. """ features = self.model(x) + # Convert NHWC to NCHW if needed if self._is_channel_last: - # Convert to channel-first (B, C, H, W). features = [ feature.permute(0, 3, 1, 2).contiguous() for feature in features ] + # Add dummy feature for scale 1/2 if missing (transformer-style models) if self._is_transformer_style: - # Models using a transformer-like hierarchy may not generate - # all expected feature maps. Insert a dummy feature map to ensure - # compatibility with decoders expecting a 5-level pyramid. B, _, H, W = x.shape dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device) - features = [x] + [dummy] + features - else: + features = [dummy] + features + + # Add input tensor as scale 1 feature if `self._is_skip_first` is False + if not self._is_skip_first: features = [x] + features return features @@ -162,22 +174,35 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]: @property def out_channels(self) -> list[int]: """ + Returns the number of output channels for each feature stage. + Returns: - list[int]: A list of output channels for each stage of the encoder, - including the input channels at the first stage. + list[int]: A list of channel dimensions at each scale. """ return self._out_channels @property def output_stride(self) -> int: """ + Returns the effective output stride based on the model depth. + Returns: - int: The effective output stride of the encoder, considering the depth. + int: The effective output stride. """ return min(self._output_stride, 2**self._depth) def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]: + """ + Merge two dictionaries, ensuring no duplicate keys exist. + + Args: + a (dict): Base dictionary. + b (dict): Additional parameters to merge. + + Returns: + dict: A merged dictionary. 
+ """ duplicates = a.keys() & b.keys() if duplicates: raise ValueError(f"'{duplicates}' already specified internally") From 330e6e5ae3dfd1708f22231e93e1eba5905e8cd5 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Wed, 18 Dec 2024 19:56:30 +0800 Subject: [PATCH 10/14] Fix ruff style --- .../encoders/timm_universal.py | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py index 92d6dc28..18f6328e 100644 --- a/segmentation_models_pytorch/encoders/timm_universal.py +++ b/segmentation_models_pytorch/encoders/timm_universal.py @@ -1,29 +1,29 @@ """ -TimmUniversalEncoder provides a unified feature extraction interface built on the -`timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style +TimmUniversalEncoder provides a unified feature extraction interface built on the +`timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style models (e.g., Swin Transformer, ConvNeXt). -This encoder produces consistent multi-level feature maps for semantic segmentation tasks. -It allows configuring the number of feature extraction stages (`depth`) and adjusting +This encoder produces consistent multi-level feature maps for semantic segmentation tasks. +It allows configuring the number of feature extraction stages (`depth`) and adjusting `output_stride` when supported. Key Features: - Flexible model selection using `timm.create_model`. -- Unified multi-level output across different model hierarchies. +- Unified multi-level output across different model hierarchies. - Automatic alignment for inconsistent feature scales: - - Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale. - - VGG-style models (include scale-1 features): Align outputs for compatibility. + - Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale. + - VGG-style models (include scale-1 features): Align outputs for compatibility. - Easy access to feature scale information via the `reduction` property. Feature Scale Differences: -- Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32. -- Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale. +- Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32. +- Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale. - VGG-style models: Include scale-1 features (input resolution). Notes: -- `output_stride` is unsupported in some models, especially transformer-based architectures. -- Special handling for models like TResNet and DLA to ensure correct feature indexing. -- VGG-style models use `_is_skip_first` to align scale-1 features with standard outputs. +- `output_stride` is unsupported in some models, especially transformer-based architectures. +- Special handling for models like TResNet and DLA to ensure correct feature indexing. +- VGG-style models use `_is_vgg_style` to align scale-1 features with standard outputs. """ from typing import Any @@ -35,7 +35,7 @@ class TimmUniversalEncoder(nn.Module): """ - A universal encoder leveraging the `timm` library for feature extraction from + A universal encoder leveraging the `timm` library for feature extraction from various model architectures, including traditional-style and transformer-style models. 
Features:
@@ -92,15 +92,15 @@ def __init__(
         if reduction_scales == [2 ** (i + 2) for i in range(encoder_stage)]:
             # Transformer-style downsampling: scales (4, 8, 16, 32)
             self._is_transformer_style = True
-            self._is_skip_first = False
+            self._is_vgg_style = False
         elif reduction_scales == [2 ** (i + 1) for i in range(encoder_stage)]:
             # Traditional-style downsampling: scales (2, 4, 8, 16, 32)
             self._is_transformer_style = False
-            self._is_skip_first = False
-        elif reduction_scales == [2 ** i for i in range(encoder_stage)]:
-            # Models including scale 1: scales (1, 2, 4, 8, 16, 32)
+            self._is_vgg_style = False
+        elif reduction_scales == [2**i for i in range(encoder_stage)]:
+            # Vgg-style models including scale 1: scales (1, 2, 4, 8, 16, 32)
             self._is_transformer_style = False
-            self._is_skip_first = True
+            self._is_vgg_style = True
         else:
             raise ValueError("Unsupported model downsampling pattern.")

@@ -125,14 +125,14 @@ def __init__(
         if "dla" in name:
             # For 'dla' models, out_indices starts at 0 and matches the input size.
             common_kwargs["out_indices"] = tuple(range(1, depth + 1))
-        if self._is_skip_first:
+        if self._is_vgg_style:
             common_kwargs["out_indices"] = tuple(range(depth + 1))

         self.model = timm.create_model(
             name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
         )

-        if self._is_skip_first:
+        if self._is_vgg_style:
             self._out_channels = self.model.feature_info.channels()
         else:
             self._out_channels = [in_channels] + self.model.feature_info.channels()
@@ -164,9 +164,9 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
             B, _, H, W = x.shape
             dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)
             features = [dummy] + features
-        
-        # Add input tensor as scale 1 feature if `self._is_skip_first` is False
-        if not self._is_skip_first:
+
+        # Add input tensor as scale 1 feature if `self._is_vgg_style` is False
+        if not self._is_vgg_style:
             features = [x] + features

         return features

From d8ea35f5bf714941baedde3326fe6b0f67496d90 Mon Sep 17 00:00:00 2001
From: Ryan <23580140+brianhou0208@users.noreply.github.com>
Date: Thu, 19 Dec 2024 00:55:18 +0800
Subject: [PATCH 11/14] Update timm_universal.py

1. Rename the temporary model to `tmp_model`
2. Create the temporary model on the meta device to speed up
   initialization (a sketch of the pattern follows below)
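A minimal sketch of the probing pattern used here, assuming a torch
version where `torch.device` works as a context manager (2.0+);
`resnet34` is only an illustrative model name:

    import timm
    import torch

    try:
        # On the meta device no storage is allocated and no weights are
        # initialized, so building the throwaway model is nearly free.
        with torch.device("meta"):
            tmp_model = timm.create_model("resnet34", features_only=True)
    except Exception:
        # Some architectures cannot be built on the meta device;
        # fall back to regular construction.
        tmp_model = timm.create_model("resnet34", features_only=True)

    # Only metadata is read from the probe, never its weights.
    print(tmp_model.feature_info.reduction())  # [2, 4, 8, 16, 32]

Because only `feature_info` and `output_fmt` are read from this probe,
skipping weight initialization makes encoder construction noticeably
faster for large backbones.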
From d8ea35f5bf714941baedde3326fe6b0f67496d90 Mon Sep 17 00:00:00 2001
From: Ryan <23580140+brianhou0208@users.noreply.github.com>
Date: Thu, 19 Dec 2024 00:55:18 +0800
Subject: [PATCH 11/14] Update timm_universal.py

1. Rename the temporary model.
2. Create the temporary model on the meta device to speed up initialization.
---
 .../encoders/timm_universal.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py
index 18f6328e..abcdb467 100644
--- a/segmentation_models_pytorch/encoders/timm_universal.py
+++ b/segmentation_models_pytorch/encoders/timm_universal.py
@@ -80,14 +80,18 @@ def __init__(
             common_kwargs.pop("output_stride")
 
         # Load a temporary model to analyze its feature hierarchy
-        self.model = timm.create_model(name, features_only=True)
+        try:
+            with torch.device("meta"):
+                tmp_model = timm.create_model(name, features_only=True)
+        except Exception:
+            tmp_model = timm.create_model(name, features_only=True)
 
         # Check if model output is in channel-last format (NHWC)
-        self._is_channel_last = getattr(self.model, "output_fmt", None) == "NHWC"
+        self._is_channel_last = getattr(tmp_model, "output_fmt", None) == "NHWC"
 
         # Determine the model's downsampling pattern and set hierarchy flags
-        encoder_stage = len(self.model.feature_info.reduction())
-        reduction_scales = self.model.feature_info.reduction()
+        encoder_stage = len(tmp_model.feature_info.reduction())
+        reduction_scales = tmp_model.feature_info.reduction()
 
         if reduction_scales == [2 ** (i + 2) for i in range(encoder_stage)]:
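The meta-device construction above lets the encoder read `feature_info` without allocating or initializing real weights. The same pattern in isolation (the model name is only an example; some architectures cannot be built on the meta device, hence the fallback):

    import timm
    import torch

    name = "resnet34"
    try:
        with torch.device("meta"):  # parameters get shapes but no real storage
            tmp_model = timm.create_model(name, features_only=True)
    except Exception:
        tmp_model = timm.create_model(name, features_only=True)

    print(tmp_model.feature_info.reduction())  # e.g. [2, 4, 8, 16, 32]
    print(tmp_model.feature_info.channels())   # e.g. [64, 64, 128, 256, 512]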
From e7bc6e048e7b40aad74d6f9132d1019f7b6bb92b Mon Sep 17 00:00:00 2001
From: Ryan <23580140+brianhou0208@users.noreply.github.com>
Date: Thu, 19 Dec 2024 08:33:21 +0800
Subject: [PATCH 12/14] Add tests/test_models & fix type

---
 .../encoders/timm_universal.py |  2 +-
 tests/test_models.py           | 24 +++++++++----------
 2 files changed, 12 insertions(+), 14 deletions(-)

diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py
index abcdb467..9bdcb188 100644
--- a/segmentation_models_pytorch/encoders/timm_universal.py
+++ b/segmentation_models_pytorch/encoders/timm_universal.py
@@ -91,7 +91,7 @@ def __init__(
 
         # Determine the model's downsampling pattern and set hierarchy flags
         encoder_stage = len(tmp_model.feature_info.reduction())
-        reduction_scales = tmp_model.feature_info.reduction()
+        reduction_scales = list(tmp_model.feature_info.reduction())
 
diff --git a/tests/test_models.py b/tests/test_models.py
index 10f697b8..80fe44ed 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -13,8 +13,11 @@ def get_encoders():
     ]
     encoders = smp.encoders.get_encoder_names()
     encoders = [e for e in encoders if e not in exclude_encoders]
-    encoders.append("tu-resnet34")  # for timm universal encoder
-    return encoders
+    encoders.append("tu-resnet34")  # for timm universal traditional-like encoder
+    encoders.append("tu-convnext_atto")  # for timm universal transformer-like encoder
+    encoders.append("tu-darknet17")  # for timm universal vgg-like encoder
+    encoders.append("mit_b0")
+    return encoders[-3:]
 
 
 ENCODERS = get_encoders()
@@ -80,16 +83,12 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
     if (
         or model_class is smp.MAnet
     ):
         kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:]
-    if model_class in [smp.UnetPlusPlus, smp.Linknet] and encoder_name.startswith(
-        "mit_b"
-    ):
-        return  # skip mit_b*
-    if (
-        model_class is smp.FPN
-        and encoder_name.startswith("mit_b")
-        and encoder_depth != 5
-    ):
-        return  # skip mit_b*
+    if model_class in [smp.UnetPlusPlus, smp.Linknet]:
+        if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"):
+            return  # skip transformer-like models
+    if model_class is smp.FPN and encoder_depth != 5:
+        if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"):
+            return  # skip transformer-like models
     model = model_class(
         encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs
     )
@@ -180,7 +179,6 @@ def test_dilation(encoder_name):
         or encoder_name.startswith("vgg")
         or encoder_name.startswith("densenet")
         or encoder_name.startswith("timm-res")
-        or encoder_name.startswith("mit_b")
     ):
         return

From dd25aa2dad6f7bc292ae867ddfd25a13b5d9fe47 Mon Sep 17 00:00:00 2001
From: Ryan <23580140+brianhou0208@users.noreply.github.com>
Date: Thu, 19 Dec 2024 08:34:16 +0800
Subject: [PATCH 13/14] Update test_models.py

---
 tests/test_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 80fe44ed..6495bee7 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -17,7 +17,7 @@ def get_encoders():
     encoders.append("tu-convnext_atto")  # for timm universal transformer-like encoder
     encoders.append("tu-darknet17")  # for timm universal vgg-like encoder
     encoders.append("mit_b0")
-    return encoders[-3:]
+    return encoders

From f55eb134bfefa91fe9071dab78f476922fde5404 Mon Sep 17 00:00:00 2001
From: Ryan <23580140+brianhou0208@users.noreply.github.com>
Date: Thu, 19 Dec 2024 08:35:24 +0800
Subject: [PATCH 14/14] Update test_models.py

---
 tests/test_models.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 6495bee7..cb495fbb 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -16,7 +16,6 @@ def get_encoders():
     encoders.append("tu-resnet34")  # for timm universal traditional-like encoder
     encoders.append("tu-convnext_atto")  # for timm universal transformer-like encoder
     encoders.append("tu-darknet17")  # for timm universal vgg-like encoder
-    encoders.append("mit_b0")
     return encoders
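With the full series applied, all three hierarchy styles go through the same interface. A short usage sketch (assuming the module layout from the diffs above; the channel counts assume timm's convnext_atto dims of 40/80/160/320, and the 0-channel map is the dummy 1/2-scale feature):

    import torch
    from segmentation_models_pytorch.encoders.timm_universal import TimmUniversalEncoder

    encoder = TimmUniversalEncoder("convnext_atto", pretrained=False)
    x = torch.rand(1, 3, 256, 256)
    for f in encoder(x):
        print(tuple(f.shape))
    # (1, 3, 256, 256)   input kept as the scale-1 entry
    # (1, 0, 128, 128)   dummy 1/2-scale feature
    # (1, 40, 64, 64)
    # (1, 80, 32, 32)
    # (1, 160, 16, 16)
    # (1, 320, 8, 8)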