Add forward_intermediates support for xcit, cait, and volo. #2162

Merged 2 commits on Apr 30, 2024

5 changes: 3 additions & 2 deletions tests/test_models.py
@@ -49,10 +49,11 @@

# models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
FEAT_INTER_FILTERS = [
'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*'
'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*',
'cait_*', 'xcit_*', 'volo_*',
]

# transformer models don't support many of the spatial / feature based model functionalities
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
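For reference, a rough sketch of what the new FEAT_INTER_FILTERS entries enable: cait, xcit, and volo models can now be created through the features_only path, which is backed by the FeatureGetterNet wrapper calling forward_intermediates(). The model name below is just one illustrative member of the covered families.

import torch
import timm

# features_only wraps the model in FeatureGetterNet, which calls forward_intermediates()
model = timm.create_model('cait_xxs24_224', features_only=True)
feats = model(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)  # NCHW feature maps from the selected blocks
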
6 changes: 4 additions & 2 deletions timm/models/_features.py
@@ -165,6 +165,7 @@ def __init__(
):
# setup feature hooks
self._feature_outputs = defaultdict(OrderedDict)
self._handles = []
modules = {k: v for k, v in named_modules}
for i, h in enumerate(hooks):
hook_name = h['module']
@@ -173,11 +174,12 @@ def __init__(
hook_fn = partial(self._collect_output_hook, hook_id)
hook_type = h.get('hook_type', default_hook_type)
if hook_type == 'forward_pre':
m.register_forward_pre_hook(hook_fn)
handle = m.register_forward_pre_hook(hook_fn)
elif hook_type == 'forward':
m.register_forward_hook(hook_fn)
handle = m.register_forward_hook(hook_fn)
else:
assert False, "Unsupported hook type"
self._handles.append(handle)

def _collect_output_hook(self, hook_id, *args):
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
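Keeping the handles returned by register_forward_pre_hook / register_forward_hook follows the standard torch RemovableHandle pattern, presumably so the hooks can be detached later. A minimal, self-contained sketch of that pattern (plain PyTorch, not timm code):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
captured = []

# register_forward_hook returns a RemovableHandle
handle = layer.register_forward_hook(lambda mod, inp, out: captured.append(out))

layer(torch.randn(1, 4))   # hook fires, output captured
handle.remove()            # detach the hook via the stored handle
layer(torch.randn(1, 4))   # hook no longer fires
print(len(captured))       # 1
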
85 changes: 80 additions & 5 deletions timm/models/cait.py
@@ -9,13 +9,15 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
from functools import partial
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, use_fused_attn
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs

@@ -246,8 +248,8 @@ def __init__(
in_chans=in_chans,
embed_dim=embed_dim,
)

num_patches = self.patch_embed.num_patches
r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
@@ -268,6 +270,7 @@ def __init__(
mlp_block=mlp_block,
init_values=init_values,
) for i in range(depth)])
self.feature_info = [dict(num_chs=embed_dim, reduction=r, module=f'blocks.{i}') for i in range(depth)]

self.blocks_token_only = nn.ModuleList([block_layers_token(
dim=embed_dim,
@@ -283,7 +286,6 @@ def __init__(

self.norm = norm_layer(embed_dim)

self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

@@ -336,6 +338,80 @@ def reset_classifier(self, num_classes, global_pool=None):
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.

Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to all intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)

# forward pass
B, _, height, width = x.shape
x = self.patch_embed(x)
x = x + self.pos_embed
x = self.pos_drop(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x)
if i in take_indices:
# normalize intermediates with final norm layer if enabled
intermediates.append(self.norm(x) if norm else x)

# process intermediates
if reshape:
# reshape to BCHW output format
H, W = self.patch_embed.dynamic_feat_size((height, width))
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]

if intermediates_only:
return intermediates

# NOTE not supporting return of class tokens
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
for i, blk in enumerate(self.blocks_token_only):
cls_tokens = blk(x, cls_tokens)
x = torch.cat((cls_tokens, x), dim=1)
x = self.norm(x)

return x, intermediates

def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), n)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.blocks_token_only = nn.ModuleList() # prune token blocks with head
self.head = nn.Identity()
return take_indices

def forward_features(self, x):
x = self.patch_embed(x)
x = x + self.pos_embed
@@ -373,14 +449,13 @@ def checkpoint_filter_fn(state_dict, model=None):


def _create_cait(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')

out_indices = kwargs.pop('out_indices', 3)
model = build_model_with_cfg(
Cait,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
return model
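A rough end-to-end sketch of the new Cait interface based on the signatures above; the model name, input size, and block indices are placeholder choices:

import torch
import timm

model = timm.create_model('cait_xxs24_224')
x = torch.randn(1, 3, 224, 224)

# take only the last two block outputs, reshaped to NCHW
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)

# or also keep the final tokens (class-attention blocks + norm applied)
final, feats = model.forward_intermediates(x, indices=[4, 11], norm=True)

# drop blocks (and the class-attention head) not needed for those intermediates
model.prune_intermediate_layers([4, 11], prune_head=True)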