
Commit 4b2565e

More forward_intermediates() / FeatureGetterNet work
* include relpos vit
* refactor reduction / size calcs so hybrid vits work and dynamic_img_size works
* fix -ve feature indices when pruning
* fix mvitv2 w/ class token
* refine naming
* add tests
1 parent ef9c6fb commit 4b2565e

11 files changed: +339, -86 lines
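
For orientation, a minimal usage sketch of the two paths this commit touches (not part of the diff; the model name and exact behaviour are assumptions about the installed timm version):

import torch
import timm

model = timm.create_model('vit_base_patch16_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# Direct call: final features plus the last 3 intermediate block outputs.
final, intermediates = model.forward_intermediates(x, indices=3)

# features_only=True now routes ViT-style models through the FeatureGetterNet
# wrapper, which calls forward_intermediates() internally.
feat_model = timm.create_model('vit_base_patch16_224', pretrained=False, features_only=True)
feats = feat_model(x)  # list of feature maps, one per selected block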

tests/test_models.py

Lines changed: 71 additions & 0 deletions
@@ -47,6 +47,11 @@
 torch._C._jit_set_profiling_executor(True)
 torch._C._jit_set_profiling_mode(False)

+# models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
+FEAT_INTER_FILTERS = [
+    'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*'
+]
+
 # transformer models don't support many of the spatial / feature based model functionalities
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
@@ -380,6 +385,72 @@ def test_model_forward_features(model_name, batch_size):
         assert not torch.isnan(o).any()


+@pytest.mark.features
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, include_tags=True))
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_forward_intermediates_features(model_name, batch_size):
+    """Run a single forward pass with each model in feature extraction mode"""
+    model = create_model(model_name, pretrained=False, features_only=True)
+    model.eval()
+    print(model.feature_info.out_indices)
+    expected_channels = model.feature_info.channels()
+    expected_reduction = model.feature_info.reduction()
+
+    input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
+    if max(input_size) > MAX_FFEAT_SIZE:
+        pytest.skip("Fixed input size model > limit.")
+    output_fmt = getattr(model, 'output_fmt', 'NCHW')
+    feat_axis = get_channel_dim(output_fmt)
+    spatial_axis = get_spatial_dim(output_fmt)
+    import math
+
+    outputs = model(torch.randn((batch_size, *input_size)))
+    assert len(expected_channels) == len(outputs)
+    spatial_size = input_size[-2:]
+    for e, r, o in zip(expected_channels, expected_reduction, outputs):
+        print(o.shape)
+        assert e == o.shape[feat_axis]
+        assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1
+        assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1
+        assert o.shape[0] == batch_size
+        assert not torch.isnan(o).any()
+
+
+@pytest.mark.features
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, include_tags=True))
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_forward_intermediates(model_name, batch_size):
+    """Run a single forward pass with each model in feature extraction mode"""
+    model = create_model(model_name, pretrained=False)
+    model.eval()
+    feature_info = timm.models.FeatureInfo(model.feature_info, len(model.feature_info))
+    expected_channels = feature_info.channels()
+    expected_reduction = feature_info.reduction()
+    assert len(expected_channels) >= 4  # all models here should have at least 4 feature levels by default, some 5 or 6
+
+    input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
+    if max(input_size) > MAX_FFEAT_SIZE:
+        pytest.skip("Fixed input size model > limit.")
+    output_fmt = getattr(model, 'output_fmt', 'NCHW')
+    feat_axis = get_channel_dim(output_fmt)
+    spatial_axis = get_spatial_dim(output_fmt)
+    import math
+
+    output, intermediates = model.forward_intermediates(
+        torch.randn((batch_size, *input_size)),
+    )
+    assert len(expected_channels) == len(intermediates)
+    spatial_size = input_size[-2:]
+    for e, r, o in zip(expected_channels, expected_reduction, intermediates):
+        assert e == o.shape[feat_axis]
+        assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1
+        assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1
+        assert o.shape[0] == batch_size
+        assert not torch.isnan(o).any()
+
+
 def _create_fx_model(model, train=False):
     # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
     # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
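
A worked instance of the spatial-size bound asserted above (numbers chosen for illustration): with a 224x224 input and a reduction of 16, each intermediate may span at most math.ceil(224 / 16) + 1 == 15 positions per spatial axis, the +1 presumably allowing slack for dynamic padding or overlapping stems.

import math

spatial, reduction = 224, 16
assert math.ceil(spatial / reduction) + 1 == 15  # 14 + 1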

timm/layers/patch_embed.py

Lines changed: 18 additions & 2 deletions
@@ -9,6 +9,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import logging
+import math
 from typing import Callable, List, Optional, Tuple, Union

 import torch
@@ -65,6 +66,21 @@ def __init__(
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

+    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
+        if as_scalar:
+            return max(self.patch_size)
+        else:
+            return self.patch_size
+
+    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
+        """ Get grid (feature) size for given image size taking account of dynamic padding.
+        NOTE: must be torchscript compatible so using fixed tuple indexing
+        """
+        if self.dynamic_img_pad:
+            return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1])
+        else:
+            return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
+
     def forward(self, x):
         B, C, H, W = x.shape
         if self.img_size is not None:
@@ -127,13 +143,13 @@ def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
         _assert(W % self.patch_size[1] == 0, f"Input image width ({W}) must be divisible by patch size ({self.patch_size[1]}).")

         x = self.proj(x)
-        grid_size = x.shape[-2:]
+        feat_size = x.shape[-2:]
         if self.flatten:
             x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
         elif self.output_fmt != Format.NCHW:
             x = nchw_to(x, self.output_fmt)
         x = self.norm(x)
-        return x, grid_size
+        return x, feat_size


 def resample_patch_embed(
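
A small numeric sketch of what the two new PatchEmbed helpers compute (values assumed, not taken from the diff): feat_ratio() reports the effective stride, and dynamic_feat_size() switches between ceil and floor division depending on dynamic_img_pad.

import math

patch_size = (16, 16)
img_size = (250, 250)

# feat_ratio(as_scalar=True) -> max(patch_size) == 16
# dynamic_img_pad=True: the input gets padded up, so the grid uses ceil division
print(math.ceil(img_size[0] / patch_size[0]), math.ceil(img_size[1] / patch_size[1]))  # 16 16
# dynamic_img_pad=False: plain integer (floor) division
print(img_size[0] // patch_size[0], img_size[1] // patch_size[1])  # 15 15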

timm/models/_features.py

Lines changed: 21 additions & 8 deletions
@@ -26,7 +26,10 @@
 ]


-def _take_indices(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[Set[int], int]:
+def _take_indices(
+        num_blocks: int,
+        n: Optional[Union[int, List[int], Tuple[int]]],
+) -> Tuple[Set[int], int]:
     if isinstance(n, int):
         assert n >= 0
         take_indices = {x for x in range(num_blocks - n, num_blocks)}
@@ -35,7 +38,10 @@ def _take_indices(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tupl
     return take_indices, max(take_indices)


-def _take_indices_jit(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[List[int], int]:
+def _take_indices_jit(
+        num_blocks: int,
+        n: Union[int, List[int], Tuple[int]],
+) -> Tuple[List[int], int]:
     if isinstance(n, int):
         assert n >= 0
         take_indices = [num_blocks - n + i for i in range(n)]
@@ -47,12 +53,17 @@ def _take_indices_jit(n: Union[int, List[int], Tuple[int]], num_blocks: int) ->
     return take_indices, max(take_indices)


-def feature_take_indices(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[List[int], int]:
+def feature_take_indices(
+        num_blocks: int,
+        indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+) -> Tuple[List[int], int]:
+    if indices is None:
+        indices = num_blocks  # all blocks if None
     if torch.jit.is_scripting():
-        return _take_indices_jit(n, num_blocks)
+        return _take_indices_jit(num_blocks, indices)
     else:
         # NOTE non-jit returns Set[int] instead of List[int] but torchscript can't handle that anno
-        return _take_indices(n, num_blocks)
+        return _take_indices(num_blocks, indices)


 def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
@@ -443,10 +454,12 @@ def __init__(
         """
         super().__init__()
         if prune and hasattr(model, 'prune_intermediate_layers'):
-            model.prune_intermediate_layers(
+            # replace out_indices after they've been normalized, -ve indices will be invalid after prune
+            out_indices = model.prune_intermediate_layers(
                 out_indices,
                 prune_norm=not norm,
             )
+            out_indices = list(out_indices)
         self.feature_info = _get_feature_info(model, out_indices)
         self.model = model
         self.out_indices = out_indices
@@ -458,9 +471,9 @@ def __init__(
     def forward(self, x):
         features = self.model.forward_intermediates(
             x,
-            n=self.out_indices,
+            indices=self.out_indices,
             norm=self.norm,
             output_fmt=self.output_fmt,
-            features_only=True,
+            intermediates_only=True,
         )
         return features
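
A sketch of the new (num_blocks, indices) calling convention (the sequence branch is not shown in this hunk and the eager path returns a set rather than a list, so treat the details as assumptions):

from timm.models._features import feature_take_indices

take, max_index = feature_take_indices(12, 3)           # int: last 3 of 12 blocks -> {9, 10, 11}, max_index 11
take, max_index = feature_take_indices(12, (0, 5, 11))  # sequence: select by matching block indices
take, max_index = feature_take_indices(12, None)        # None: all 12 blocks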

timm/models/beit.py

Lines changed: 12 additions & 15 deletions
@@ -302,6 +302,7 @@ def __init__(
             embed_dim=embed_dim,
         )
         num_patches = self.patch_embed.num_patches
+        r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
@@ -334,7 +335,7 @@ def __init__(
             )
             for i in range(depth)])
         self.feature_info = [
-            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)]
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]

         use_fc_norm = self.global_pool == 'avg'
         self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
@@ -403,33 +404,30 @@ def reset_classifier(self, num_classes, global_pool=None):
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            n: Optional[Union[int, List[int], Tuple[int]]] = None,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             return_prefix_tokens: bool = False,
             norm: bool = False,
             stop_early: bool = True,
             output_fmt: str = 'NCHW',
-            features_only: bool = False,
+            intermediates_only: bool = False,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
         """ Forward features that returns intermediates.

         Args:
             x: Input image tensor
-            n: Take last n blocks if n is an int, if in is a sequence, select by matching indices
+            indices: Take last n blocks if an int, if is a sequence, select by matching indices
             return_prefix_tokens: Return both prefix and spatial intermediate tokens
             norm: Apply norm layer to all intermediates
             stop_early: Stop iterating over blocks when last desired intermediate hit
             output_fmt: Shape of intermediate feature outputs
-            features_only: Only return intermediate features
+            intermediates_only: Only return intermediate features
         Returns:

         """
         assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
         reshape = output_fmt == 'NCHW'
         intermediates = []
-        num_blocks = len(self.blocks)
-        if n is None:
-            n = num_blocks
-        take_indices, max_index = feature_take_indices(n, num_blocks)
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)

         # forward pass
         B, _, height, width = x.shape
@@ -455,16 +453,14 @@ def forward_intermediates(
             prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
             intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
         if reshape:
-            # reshape == True => BCHW output format
-            patch_size = self.patch_embed.patch_size
-            H = int(math.ceil(height / patch_size[0]))
-            W = int(math.ceil(width / patch_size[1]))
+            # reshape to BCHW output format
+            H, W = self.patch_embed.dynamic_feat_size((height, width))
             intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
         if not torch.jit.is_scripting() and return_prefix_tokens:
             # return_prefix not support in torchscript due to poor type handling
             intermediates = list(zip(intermediates, prefix_tokens))

-        if features_only:
+        if intermediates_only:
             return intermediates

         x = self.norm(x)
@@ -479,13 +475,14 @@ def prune_intermediate_layers(
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(n, len(self.blocks))
+        take_indices, max_index = feature_take_indices(len(self.blocks), n)
         self.blocks = self.blocks[:max_index + 1]  # truncate blocks
         if prune_norm:
             self.norm = nn.Identity()
         if prune_head:
             self.fc_norm = nn.Identity()
             self.head = nn.Identity()
+        return take_indices

     def forward_features(self, x):
         x = self.patch_embed(x)
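
A sketch of how the new return value from prune_intermediate_layers() is meant to be used, mirroring the FeatureGetterNet path above (model name, argument defaults, and the handling of negative indices are assumptions about code not shown in this diff):

import torch
import timm

model = timm.create_model('beit_base_patch16_224', pretrained=False)

# Relative / negative selections are resolved to absolute block indices and
# returned, so they stay valid after the trailing blocks have been truncated.
out_indices = model.prune_intermediate_layers((-2, -1), prune_norm=True)

feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224),
    indices=list(out_indices),
    intermediates_only=True,
)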

timm/models/eva.py

Lines changed: 12 additions & 15 deletions
@@ -424,6 +424,7 @@ def __init__(
             **embed_args,
         )
         num_patches = self.patch_embed.num_patches
+        r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None

@@ -470,7 +471,7 @@ def __init__(
             )
             for i in range(depth)])
         self.feature_info = [
-            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)]
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]

         use_fc_norm = self.global_pool == 'avg'
         self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
@@ -564,30 +565,27 @@ def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            n: Optional[Union[int, List[int], Tuple[int]]] = None,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             return_prefix_tokens: bool = False,
             norm: bool = False,
             stop_early: bool = True,
             output_fmt: str = 'NCHW',
-            features_only: bool = False,
+            intermediates_only: bool = False,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
         """ Forward features that returns intermediates.
         Args:
             x: Input image tensor
-            n: Take last n blocks if n is an int, if in is a sequence, select by matching indices
+            indices: Take last n blocks if an int, if is a sequence, select by matching indices
             return_prefix_tokens: Return both prefix and spatial intermediate tokens
             norm: Apply norm layer to all intermediates
             stop_early: Stop iterating over blocks when last desired intermediate hit
             output_fmt: Shape of intermediate feature outputs
-            features_only: Only return intermediate features
+            intermediates_only: Only return intermediate features
         """
         assert output_fmt in ('NCHW', 'NLC'), 'Output format for EVA-ViT features must be one of NCHW or NLC.'
         reshape = output_fmt == 'NCHW'
         intermediates = []
-        num_blocks = len(self.blocks)
-        if n is None:
-            n = num_blocks
-        take_indices, max_index = feature_take_indices(n, num_blocks)
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)

         # forward pass
         B, _, height, width = x.shape
@@ -608,16 +606,14 @@ def forward_intermediates(
             prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
             intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
         if reshape:
-            # reshape == True => BCHW output format
-            patch_size = self.patch_embed.patch_size
-            H = int(math.ceil(height / patch_size[0]))
-            W = int(math.ceil(width / patch_size[1]))
+            # reshape to BCHW output format
+            H, W = self.patch_embed.dynamic_feat_size((height, width))
             intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
         if not torch.jit.is_scripting() and return_prefix_tokens:
             # return_prefix not support in torchscript due to poor type handling
             intermediates = list(zip(intermediates, prefix_tokens))

-        if features_only:
+        if intermediates_only:
             return intermediates

         x = self.norm(x)
@@ -632,13 +628,14 @@ def prune_intermediate_layers(
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(n, len(self.blocks))
+        take_indices, max_index = feature_take_indices(len(self.blocks), n)
         self.blocks = self.blocks[:max_index + 1]  # truncate blocks
         if prune_norm:
             self.norm = nn.Identity()
         if prune_head:
             self.fc_norm = nn.Identity()
             self.head = nn.Identity()
+        return take_indices

     def forward_features(self, x):
         x = self.patch_embed(x)
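
And a sketch of the dynamic-size path the reshape change feeds into (the model name, the dynamic_img_size flag, and the resulting grid sizes are assumptions, not taken from this diff):

import torch
import timm

model = timm.create_model('eva02_tiny_patch14_224', pretrained=False, dynamic_img_size=True)

# With a 14x14 patch embed, a 196x252 input should give 14x18 feature grids;
# H and W now come from patch_embed.dynamic_feat_size() rather than a
# hard-coded division by patch_size.
feats = model.forward_intermediates(
    torch.randn(1, 3, 196, 252),
    indices=2,
    intermediates_only=True,
)
for f in feats:
    print(f.shape)  # e.g. torch.Size([1, 192, 14, 18])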
