
Commit 717528a

refactor codebase, add timesformer support, improve tests (#24)
* refactor codebase, add timesformer support, improve tests
* reformat
* reformat
* update workflow order
* fix a test, add styling script
* fix readme
1 parent 8b85f91 commit 717528a

14 files changed: +263 -51 lines changed

.github/workflows/ci.yml

+6 -3

@@ -64,9 +64,6 @@ jobs:
         if: matrix.operating-system == 'macos-latest'
         run: pip install torch==${{ matrix.torch-version }}
 
-      - name: Install Pytorchvideo from main branch
-        run: pip install git+https://github.com/facebookresearch/pytorchvideo.git
-
       - name: Lint with flake8, black and isort
         run: |
           pip install .[dev]
@@ -77,6 +74,12 @@ jobs:
           # exit-zero treats all errors as warnings. Allowed max line length is 120.
           flake8 . --count --exit-zero --max-complexity=10 --max-line-length=120 --statistics
 
+      - name: Install Pytorchvideo from main branch
+        run: pip install git+https://github.com/facebookresearch/pytorchvideo.git
+
+      - name: Install HF/Transformers from main branch
+        run: pip install -U git+https://github.com/huggingface/transformers.git
+
       - name: Install video-transformers package from local setup.py
         run: >
           pip install .

.github/workflows/package_testing.yml

+3

@@ -66,6 +66,9 @@ jobs:
       - name: Install Pytorchvideo from main branch
         run: pip install git+https://github.com/facebookresearch/pytorchvideo.git
 
+      - name: Install HF/Transformers from main branch
+        run: pip install -U git+https://github.com/huggingface/transformers.git
+
       - name: Install latest video-transformers package
         run: >
           pip install --upgrade --force-reinstall video-transformers[test]

README.md

+66 -7

@@ -44,10 +44,11 @@ and supports:
 conda install pytorch=1.11.0 torchvision=0.12.0 cudatoolkit=11.3 -c pytorch
 ```
 
-- Install pytorchvideo from main branch:
+- Install pytorchvideo and transformers from main branch:
 
 ```bash
 pip install git+https://github.com/facebookresearch/pytorchvideo.git
+pip install git+https://github.com/huggingface/transformers.git
 ```
 
 - Install `video-transformers`:
@@ -83,7 +84,48 @@ val_root
 ...
 ```
 
-- Fine-tune CVT (from HuggingFace) + Transformer based video classifier:
+- Fine-tune Timesformer (from HuggingFace) video classifier:
+
+```python
+from torch.optim import AdamW
+from video_transformers import VideoModel
+from video_transformers.backbones.transformers import TransformersBackbone
+from video_transformers.data import VideoDataModule
+from video_transformers.heads import LinearHead
+from video_transformers.trainer import trainer_factory
+from video_transformers.utils.file import download_ucf6
+
+backbone = TransformersBackbone("facebook/timesformer-base-finetuned-k400", num_unfrozen_stages=1)
+
+download_ucf6("./")
+datamodule = VideoDataModule(
+    train_root="ucf6/train",
+    val_root="ucf6/val",
+    batch_size=4,
+    num_workers=4,
+    num_timesteps=8,
+    preprocess_input_size=224,
+    preprocess_clip_duration=1,
+    preprocess_means=backbone.mean,
+    preprocess_stds=backbone.std,
+    preprocess_min_short_side=256,
+    preprocess_max_short_side=320,
+    preprocess_horizontal_flip_p=0.5,
+)
+
+head = LinearHead(hidden_size=backbone.num_features, num_classes=datamodule.num_classes)
+model = VideoModel(backbone, head)
+
+optimizer = AdamW(model.parameters(), lr=1e-4)
+
+Trainer = trainer_factory("single_label_classification")
+trainer = Trainer(datamodule, model, optimizer=optimizer, max_epochs=8)
+
+trainer.fit()
+
+```
+
+- Fine-tune ConvNeXT (from HuggingFace) + Transformer based video classifier:
 
 ```python
 from torch.optim import AdamW
@@ -95,7 +137,7 @@ from video_transformers.necks import TransformerNeck
 from video_transformers.trainer import trainer_factory
 from video_transformers.utils.file import download_ucf6
 
-backbone = TimeDistributed(TransformersBackbone("microsoft/cvt-13", num_unfrozen_stages=0))
+backbone = TimeDistributed(TransformersBackbone("facebook/convnext-small-224", num_unfrozen_stages=1))
 neck = TransformerNeck(
     num_features=backbone.num_features,
     num_timesteps=8,
@@ -137,18 +179,18 @@ trainer.fit()
 
 ```
 
-- Fine-tune MobileViT (from Timm) + GRU based video classifier:
+- Fine-tune Resnet18 (from HuggingFace) + GRU based video classifier:
 
 ```python
 from video_transformers import TimeDistributed, VideoModel
-from video_transformers.backbones.timm import TimmBackbone
+from video_transformers.backbones.transformers import TransformersBackbone
 from video_transformers.data import VideoDataModule
 from video_transformers.heads import LinearHead
 from video_transformers.necks import GRUNeck
 from video_transformers.trainer import trainer_factory
 from video_transformers.utils.file import download_ucf6
 
-backbone = TimeDistributed(TimmBackbone("mobilevitv2_100", num_unfrozen_stages=0))
+backbone = TimeDistributed(TransformersBackbone("microsoft/resnet-18", num_unfrozen_stages=1))
 neck = GRUNeck(num_features=backbone.num_features, hidden_size=128, num_layers=2, return_last=True)
 
 download_ucf6("./")
@@ -188,7 +230,7 @@ from video_transformers import VideoModel
 
 model = VideoModel.from_pretrained(model_name_or_path)
 
-model.predict(video_path="video.mp4")
+model.predict(video_or_folder_path="video.mp4")
 >> [{'filename': "video.mp4", 'predictions': {'class1': 0.98, 'class2': 0.02}}]
 ```
 
@@ -277,3 +319,20 @@ from video_transformers import VideoModel
 model = VideoModel.from_pretrained("runs/exp/checkpoint")
 model.to_gradio(examples=['video.mp4'], export_dir="runs/exports/", export_filename="app.py")
 ```
+
+
+## Contributing
+
+Before opening a PR:
+
+- Install required development packages:
+
+```bash
+pip install -e ."[dev]"
+```
+
+- Reformat with black and isort:
+
+```bash
+python -m tests.run_code_style format
+```
requirements.txt

+1 -1

@@ -1,6 +1,6 @@
 accelerate>=0.14.0,<0.15.0
 evaluate>=0.3.0,<0.4.0
-transformers>=4.24.0,<4.25.0
+transformers>=4.25.0
 timm>=0.6.12,<0.7.0
 click==8.0.4
 balanced-loss

tests/run_code_style.py

+16

@@ -0,0 +1,16 @@
+import sys
+
+from tests.utils import shell, validate_and_exit
+
+if __name__ == "__main__":
+    arg = sys.argv[1]
+
+    if arg == "check":
+        sts_flake = shell("flake8 . --config setup.cfg --select=E9,F63,F7,F82")
+        sts_isort = shell("isort . --check --settings pyproject.toml")
+        sts_black = shell("black . --check --config pyproject.toml")
+        validate_and_exit(flake8=sts_flake, isort=sts_isort, black=sts_black)
+    elif arg == "format":
+        sts_isort = shell("isort . --settings pyproject.toml")
+        sts_black = shell("black . --config pyproject.toml")
+        validate_and_exit(isort=sts_isort, black=sts_black)

tests/test_auto_backbone.py

+15 -17

@@ -8,9 +8,9 @@ def test_transformers_backbone(self):
         from video_transformers import AutoBackbone
 
         config = {
-            "framework": {"name": "timm"},
+            "framework": {"name": "transformers"},
             "type": "2d_backbone",
-            "model_name": "mobilevitv2_100",
+            "model_name": "microsoft/resnet-18",
             "num_timesteps": 8,
         }
         batch_size = 2
@@ -20,23 +20,21 @@ def test_transformers_backbone(self):
         output = backbone(input)
         self.assertEqual(output.shape, (batch_size, config["num_timesteps"], backbone.num_features))
 
-    def test_timm_backbone(self):
-        import torch
-
+    def test_from_transformers(self):
         from video_transformers import AutoBackbone
 
-        config = {
-            "framework": {"name": "transformers"},
-            "type": "2d_backbone",
-            "model_name": "microsoft/cvt-13",
-            "num_timesteps": 8,
-        }
-        batch_size = 2
-
-        backbone = AutoBackbone.from_config(config)
-        input = torch.randn(batch_size, 3, config["num_timesteps"], 224, 224)
-        output = backbone(input)
-        self.assertEqual(output.shape, (batch_size, config["num_timesteps"], backbone.num_features))
+        backbone = AutoBackbone.from_transformers("facebook/timesformer-base-finetuned-k400")
+        assert backbone.model_name == "facebook/timesformer-base-finetuned-k400"
+        backbone = AutoBackbone.from_transformers("facebook/timesformer-base-finetuned-k600")
+        assert backbone.model_name == "facebook/timesformer-base-finetuned-k600"
+        backbone = AutoBackbone.from_transformers("facebook/timesformer-hr-finetuned-k400")
+        assert backbone.model_name == "facebook/timesformer-hr-finetuned-k400"
+        backbone = AutoBackbone.from_transformers("facebook/timesformer-hr-finetuned-k600")
+        assert backbone.model_name == "facebook/timesformer-hr-finetuned-k600"
+        backbone = AutoBackbone.from_transformers("facebook/timesformer-base-finetuned-ssv2")
+        assert backbone.model_name == "facebook/timesformer-base-finetuned-ssv2"
+        backbone = AutoBackbone.from_transformers("facebook/timesformer-hr-finetuned-ssv2")
+        assert backbone.model_name == "facebook/timesformer-hr-finetuned-ssv2"
 
 
 if __name__ == "__main__":

tests/test_auto_head.py

+17 -1

@@ -2,7 +2,7 @@
 
 
 class TestAutoHead(unittest.TestCase):
-    def test_liear_head(self):
+    def test_linear_head(self):
         import torch
 
         from video_transformers import AutoHead
@@ -20,6 +20,22 @@ def test_liear_head(self):
         output = head(input)
         self.assertEqual(output.shape, (batch_size, config["num_classes"]))
 
+    def test_from_transformers(self):
+        from video_transformers import AutoHead
+
+        linear_head = AutoHead.from_transformers("facebook/timesformer-base-finetuned-k400")
+        assert linear_head.num_classes == 400
+        linear_head = AutoHead.from_transformers("facebook/timesformer-base-finetuned-k600")
+        assert linear_head.num_classes == 600
+        linear_head = AutoHead.from_transformers("facebook/timesformer-hr-finetuned-k400")
+        assert linear_head.num_classes == 400
+        linear_head = AutoHead.from_transformers("facebook/timesformer-hr-finetuned-k600")
+        assert linear_head.num_classes == 600
+        linear_head = AutoHead.from_transformers("facebook/timesformer-base-finetuned-ssv2")
+        assert linear_head.num_classes == 174
+        linear_head = AutoHead.from_transformers("facebook/timesformer-hr-finetuned-ssv2")
+        assert linear_head.num_classes == 174
+
 
 if __name__ == "__main__":
     unittest.main()

tests/utils.py

+41

@@ -0,0 +1,41 @@
+import os
+import shutil
+import sys
+
+
+def shell(command, exit_status=0):
+    """
+    Run command through shell and return exit status if exit status of command run match with given exit status.
+
+    Args:
+        command: (str) Command string which runs through system shell.
+        exit_status: (int) Expected exit status of given command run.
+
+    Returns: actual_exit_status
+
+    """
+    actual_exit_status = os.system(command)
+    if actual_exit_status == exit_status:
+        return 0
+    return actual_exit_status
+
+
+def validate_and_exit(expected_out_status=0, **kwargs):
+    if all([arg == expected_out_status for arg in kwargs.values()]):
+        # Expected status, OK
+        sys.exit(0)
+    else:
+        # Failure
+        print_console_centered("Summary Results")
+        fail_count = 0
+        for component, exit_status in kwargs.items():
+            if exit_status != expected_out_status:
+                print(f"{component} failed.")
+                fail_count += 1
+        print_console_centered(f"{len(kwargs)-fail_count} success, {fail_count} failure")
+        sys.exit(1)
+
+
+def print_console_centered(text: str, fill_char="="):
+    w, _ = shutil.get_terminal_size((80, 20))
+    print(f" {text} ".center(w, fill_char))

video_transformers/__init__.py

+1 -1

@@ -3,4 +3,4 @@
 from video_transformers.auto.neck import AutoNeck
 from video_transformers.modeling import TimeDistributed, VideoModel
 
-__version__ = "0.0.7"
+__version__ = "0.0.8"

video_transformers/auto/backbone.py

+12 -9

@@ -15,19 +15,22 @@ def from_config(cls, config: Dict) -> Union[Backbone, TimeDistributed]:
         backbone_type = config.get("type")
         backbone_model_name = config.get("model_name")
 
-        if backbone_framework["name"] == "transformers":
-            from video_transformers.backbones.transformers import TransformersBackbone
+        from video_transformers.backbones.transformers import TransformersBackbone
 
-            backbone = TransformersBackbone(model_name=backbone_model_name)
-        elif backbone_framework["name"] == "timm":
-            from video_transformers.backbones.timm import TimmBackbone
-
-            backbone = TimmBackbone(model_name=backbone_model_name)
-        else:
-            raise ValueError(f"Unknown framework {backbone_framework}")
+        backbone = TransformersBackbone(model_name=backbone_model_name)
 
         if backbone_type == "2d_backbone":
             from video_transformers.modeling import TimeDistributed
 
             backbone = TimeDistributed(backbone)
         return backbone
+
+    @classmethod
+    def from_transformers(cls, name_or_path: str) -> Union[Backbone, TimeDistributed]:
+        from video_transformers.backbones.transformers import TransformersBackbone
+
+        backbone = TransformersBackbone(model_name=name_or_path)
+
+        if backbone.type == "2d_backbone":
+            raise ValueError("2D backbones are not supported for from_transformers method.")
+        return backbone
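
For orientation, the tests added in this commit (tests/test_auto_backbone.py) exercise the new `from_transformers` entry point roughly as in the sketch below. It is illustrative only and relies just on attributes that appear elsewhere in this diff (`model_name`, `num_features`):

```python
# Illustrative sketch based on tests/test_auto_backbone.py and the README examples in this commit.
from video_transformers import AutoBackbone

# Load a video backbone directly from a Hugging Face checkpoint.
# Per the code above, 2D backbones raise a ValueError here; use from_config for those.
backbone = AutoBackbone.from_transformers("facebook/timesformer-base-finetuned-k400")

print(backbone.model_name)    # "facebook/timesformer-base-finetuned-k400"
print(backbone.num_features)  # feature size consumed by LinearHead in the README examples
```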

video_transformers/auto/head.py

+12

@@ -18,3 +18,15 @@ def from_config(cls, config: Dict):
             return LinearHead(hidden_size, num_classes, dropout_p)
         else:
             raise ValueError(f"Unsupported head class name: {head_class_name}")
+
+    @classmethod
+    def from_transformers(cls, name_or_path: str):
+        from transformers import AutoModelForVideoClassification
+
+        from video_transformers.heads import LinearHead
+
+        model = AutoModelForVideoClassification.from_pretrained(name_or_path)
+        linear_head = LinearHead(model.classifier.in_features, model.classifier.out_features)
+        linear_head.linear.weight = model.classifier.weight
+        linear_head.linear.bias = model.classifier.bias
+        return linear_head
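
The new `AutoHead.from_transformers` copies the classifier weights of a Hugging Face video-classification checkpoint into a `LinearHead`. A hedged end-to-end sketch is shown below; the `num_classes` values come from tests/test_auto_head.py, and pairing the head with `AutoBackbone` and `VideoModel` follows the README pattern rather than code shown verbatim in this diff:

```python
# Illustrative sketch; num_classes values follow tests/test_auto_head.py.
from video_transformers import AutoBackbone, AutoHead, VideoModel

checkpoint = "facebook/timesformer-base-finetuned-k400"

backbone = AutoBackbone.from_transformers(checkpoint)
head = AutoHead.from_transformers(checkpoint)  # LinearHead initialized from the HF classifier weights
print(head.num_classes)  # 400 for the Kinetics-400 checkpoint

# Assembling backbone + head follows the README's VideoModel(backbone, head) pattern.
model = VideoModel(backbone, head)
```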
