Skip to content

Commit d6b9552

Browse files
authored
Merge pull request #2136 from huggingface/vit_features_only
Exploring vit features_only via new forward_intermediates() API, inspired by #2131
2 parents 24f6d4f + fe3cf54 commit d6b9552

23 files changed

+864
-127
lines changed

tests/test_models.py

Lines changed: 75 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -47,11 +47,16 @@
4747
torch._C._jit_set_profiling_executor(True)
4848
torch._C._jit_set_profiling_mode(False)
4949

50+
# models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
51+
FEAT_INTER_FILTERS = [
52+
'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*'
53+
]
54+
5055
# transformer models don't support many of the spatial / feature based model functionalities
5156
NON_STD_FILTERS = [
5257
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
53-
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
54-
'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
58+
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
59+
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
5560
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*'
5661
]
5762
NUM_NON_STD = len(NON_STD_FILTERS)
@@ -351,15 +356,46 @@ def test_model_forward_torchscript(model_name, batch_size):
351356

352357
@pytest.mark.features
353358
@pytest.mark.timeout(120)
354-
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS, include_tags=True))
359+
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS))
355360
@pytest.mark.parametrize('batch_size', [1])
356361
def test_model_forward_features(model_name, batch_size):
357362
"""Run a single forward pass with each model in feature extraction mode"""
358363
model = create_model(model_name, pretrained=False, features_only=True)
359364
model.eval()
360365
expected_channels = model.feature_info.channels()
361366
expected_reduction = model.feature_info.reduction()
362-
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
367+
assert len(expected_channels) >= 3 # all models here should have at least 3 default feat levels
368+
369+
input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
370+
if max(input_size) > MAX_FFEAT_SIZE:
371+
pytest.skip("Fixed input size model > limit.")
372+
output_fmt = getattr(model, 'output_fmt', 'NCHW')
373+
feat_axis = get_channel_dim(output_fmt)
374+
spatial_axis = get_spatial_dim(output_fmt)
375+
import math
376+
377+
outputs = model(torch.randn((batch_size, *input_size)))
378+
assert len(expected_channels) == len(outputs)
379+
spatial_size = input_size[-2:]
380+
for e, r, o in zip(expected_channels, expected_reduction, outputs):
381+
assert e == o.shape[feat_axis]
382+
assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1
383+
assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1
384+
assert o.shape[0] == batch_size
385+
assert not torch.isnan(o).any()
386+
387+
388+
@pytest.mark.features
389+
@pytest.mark.timeout(120)
390+
@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
391+
@pytest.mark.parametrize('batch_size', [1])
392+
def test_model_forward_intermediates_features(model_name, batch_size):
393+
"""Run a single forward pass with each model in feature extraction mode"""
394+
model = create_model(model_name, pretrained=False, features_only=True)
395+
model.eval()
396+
print(model.feature_info.out_indices)
397+
expected_channels = model.feature_info.channels()
398+
expected_reduction = model.feature_info.reduction()
363399

364400
input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
365401
if max(input_size) > MAX_FFEAT_SIZE:
@@ -373,6 +409,41 @@ def test_model_forward_features(model_name, batch_size):
373409
assert len(expected_channels) == len(outputs)
374410
spatial_size = input_size[-2:]
375411
for e, r, o in zip(expected_channels, expected_reduction, outputs):
412+
print(o.shape)
413+
assert e == o.shape[feat_axis]
414+
assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1
415+
assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1
416+
assert o.shape[0] == batch_size
417+
assert not torch.isnan(o).any()
418+
419+
420+
@pytest.mark.features
421+
@pytest.mark.timeout(120)
422+
@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
423+
@pytest.mark.parametrize('batch_size', [1])
424+
def test_model_forward_intermediates(model_name, batch_size):
425+
"""Run a single forward pass with each model in feature extraction mode"""
426+
model = create_model(model_name, pretrained=False)
427+
model.eval()
428+
feature_info = timm.models.FeatureInfo(model.feature_info, len(model.feature_info))
429+
expected_channels = feature_info.channels()
430+
expected_reduction = feature_info.reduction()
431+
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
432+
433+
input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
434+
if max(input_size) > MAX_FFEAT_SIZE:
435+
pytest.skip("Fixed input size model > limit.")
436+
output_fmt = getattr(model, 'output_fmt', 'NCHW')
437+
feat_axis = get_channel_dim(output_fmt)
438+
spatial_axis = get_spatial_dim(output_fmt)
439+
import math
440+
441+
output, intermediates = model.forward_intermediates(
442+
torch.randn((batch_size, *input_size)),
443+
)
444+
assert len(expected_channels) == len(intermediates)
445+
spatial_size = input_size[-2:]
446+
for e, r, o in zip(expected_channels, expected_reduction, intermediates):
376447
assert e == o.shape[feat_axis]
377448
assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1
378449
assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1

timm/layers/patch_embed.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Hacked together by / Copyright 2020 Ross Wightman
1010
"""
1111
import logging
12+
import math
1213
from typing import Callable, List, Optional, Tuple, Union
1314

1415
import torch
@@ -65,6 +66,21 @@ def __init__(
6566
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
6667
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
6768

69+
def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
    """Return the patch size, either as a single number or as stored.

    Args:
        as_scalar: When true (the default), return the largest entry of
            ``self.patch_size``. When false, return ``self.patch_size``
            unchanged.

    Returns:
        ``max(self.patch_size)`` if ``as_scalar`` is true, otherwise
        ``self.patch_size`` itself.
    """
    # NOTE(review): presumably used as the input-to-feature reduction factor
    # of the patch embedding -- confirm against callers.
    if as_scalar:
        return max(self.patch_size)
    else:
        return self.patch_size
74+
75+
def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
    """ Get grid (feature) size for given image size taking account of dynamic padding.

    NOTE: must be torchscript compatible so using fixed tuple indexing

    Args:
        img_size: Two-element (first dim, second dim) image size; each entry is
            divided by the matching entry of ``self.patch_size``.

    Returns:
        Two-element grid size. Each dimension is rounded up when
        ``self.dynamic_img_pad`` is truthy, and rounded down otherwise.
    """
    if self.dynamic_img_pad:
        # Ceil division: a partial trailing patch still counts as a grid cell.
        return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1])
    else:
        # Floor division: only whole patches contribute to the grid.
        return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
83+
6884
def forward(self, x):
6985
B, C, H, W = x.shape
7086
if self.img_size is not None:
@@ -127,13 +143,13 @@ def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
127143
_assert(W % self.patch_size[1] == 0, f"Input image width ({W}) must be divisible by patch size ({self.patch_size[1]}).")
128144

129145
x = self.proj(x)
130-
grid_size = x.shape[-2:]
146+
feat_size = x.shape[-2:]
131147
if self.flatten:
132148
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
133149
elif self.output_fmt != Format.NCHW:
134150
x = nchw_to(x, self.output_fmt)
135151
x = self.norm(x)
136-
return x, grid_size
152+
return x, feat_size
137153

138154

139155
def resample_patch_embed(

timm/models/_builder.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torch import nn as nn
88
from torch.hub import load_state_dict_from_url
99

10-
from timm.models._features import FeatureListNet, FeatureHookNet
10+
from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
1111
from timm.models._features_fx import FeatureGraphNet
1212
from timm.models._helpers import load_state_dict
1313
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
@@ -428,8 +428,12 @@ def build_model_with_cfg(
428428
feature_cls = feature_cls.lower()
429429
if 'hook' in feature_cls:
430430
feature_cls = FeatureHookNet
431+
elif feature_cls == 'dict':
432+
feature_cls = FeatureDictNet
431433
elif feature_cls == 'fx':
432434
feature_cls = FeatureGraphNet
435+
elif feature_cls == 'getter':
436+
feature_cls = FeatureGetterNet
433437
else:
434438
assert False, f'Unknown feature class {feature_cls}'
435439
model = feature_cls(model, **feature_cfg)

0 commit comments

Comments
 (0)