
Commit 7fb4341 ("initial commit", 1 parent: bdc5d98)

16 files changed: +1883 -2 lines

Pipfile (+31 lines)

[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"

[packages]
av = "*"
pandas = "*"
torch = "==1.7.1"
torchvision = "==0.8.2"
hydra-core = "*"
tensorboard = "==2.3.0"
logzero = "*"
coloredlogs = "*"
hydra-colorlog = "*"
tqdm = "*"
scikit-video = "*"
hydra = "*"
fvcore = "*"

[dev-packages]
isort = "*"
ipdb = "*"
black = "*"
vulture = "*"

[requires]
python_version = "3.7"

[pipenv]
allow_prereleases = true

Pipfile.lock (+923 lines)

Generated file; contents not rendered.

README.md (+55 -2 lines)

Removed (the original two-line stub):

# Action-Recognition-CNN-LSTM
Action recognition tutorial using UCF-101 dataset.

Added:

# Action Recognition in Video

This repo serves as a playground where I investigate different approaches to the problem of action recognition in video.

I will mainly use the [UCF-101 dataset](https://www.crcv.ucf.edu/data/UCF101.php).

<p align="center">
    <img src="assets/crawling.gif" width="400"/>
</p>

## Setup

```
$ cd data/
$ bash download_ucf101.sh     # Downloads the UCF-101 dataset (~7.2 GB)
$ unrar x UCF101.rar          # Unrars the dataset
$ unzip ucfTrainTestlist.zip  # Unzips the train / test split
$ python3 extract_frames.py   # Extracts frames from the videos (~26.2 GB; go grab a coffee)
```

## ConvLSTM

The only approach investigated so far. It performs action recognition in video with a bi-directional LSTM operating on frame embeddings extracted by a ResNet-152 pre-trained on ImageNet.

The model is composed of two parts, sketched in the snippet below:
* A convolutional feature extractor (ResNet-152) which provides a latent representation of each video frame
* A bi-directional LSTM classifier which predicts the depicted activity from the sequence of latent representations
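
For orientation, here is a minimal PyTorch sketch of that architecture. The class and attribute names are placeholders rather than the repo's actual module; it assumes torchvision's ImageNet-pretrained ResNet-152 and classifies from the final LSTM step (the attention option mentioned in the configs is omitted).

```
import torch.nn as nn
from torchvision import models


class ConvLSTM(nn.Module):
    # Hypothetical sketch; dimensions mirror configs/default.yaml
    def __init__(self, num_classes=101, latent_dim=512, hidden_dim=1024,
                 lstm_layers=1, bidirectional=True):
        super().__init__()
        resnet = models.resnet152(pretrained=True)
        # Drop the final fc layer; keep the pooled 2048-d features
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.embed = nn.Linear(resnet.fc.in_features, latent_dim)
        self.lstm = nn.LSTM(latent_dim, hidden_dim, lstm_layers,
                            batch_first=True, bidirectional=bidirectional)
        self.classifier = nn.Linear(hidden_dim * (2 if bidirectional else 1),
                                    num_classes)

    def forward(self, x):  # x: (batch, seq_len, 3, H, W)
        b, t = x.shape[:2]
        feats = self.encoder(x.reshape(b * t, *x.shape[2:]))  # (b*t, 2048, 1, 1)
        feats = self.embed(feats.flatten(1)).view(b, t, -1)   # (b, t, latent_dim)
        out, _ = self.lstm(feats)                             # (b, t, dirs*hidden)
        return self.classifier(out[:, -1])                    # class logits
```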

I have made a trained model available [here](https://drive.google.com/open?id=1GlpN0m9uLbI9dg1ARbW9hDEf-VWe4Asl).

### Train

```
$ python3 train.py --dataset_path data/UCF-101-frames/ \
                   --split_path data/ucfTrainTestlist \
                   --num_epochs 200 \
                   --sequence_length 40 \
                   --img_dim 112 \
                   --latent_dim 512
```

### Test on Video

```
$ python3 test_on_video.py --video_path data/UCF-101/SoccerPenalty/v_SoccerPenalty_g01_c01.avi \
                           --checkpoint_model model_checkpoints/ConvLSTM_150.pth
```

<p align="center">
    <img src="assets/penalty.gif" width="400"/>
</p>

### Results

The model reaches a classification accuracy of **91.27%** on a randomly sampled test set composed of 20% of the video sequences in UCF-101. I will re-train the model on the official train / test splits and post the results as soon as I have time.
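
For concreteness, the random split presumably looks something like the following sketch, splitting at the level of extracted frame folders (an assumption; the repo's actual split code may differ):

```
import glob
import random

# Assumed 80/20 random split over extracted sequences, not the repo's exact code
random.seed(0)
sequences = glob.glob("data/UCF-101-frames/*/*")
random.shuffle(sequences)
split = int(0.8 * len(sequences))
train_seqs, test_seqs = sequences[:split], sequences[split:]
```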

configs/debug/train_debug.yaml (+9 lines)

# @package _group_

# train
train:
  num_epochs: 2
  batch_size: 1
  sequence_length: 2
  img_dim: 112
  num_workers: 0

configs/default.yaml (+55 lines)

defaults:
  - hydra/job_logging: colorlog
  - hydra/hydra_logging: colorlog

hydra:
  run:
    dir: ./outputs
  output_subdir: ./configs/${now:%Y-%m-%d}/${now:%H-%M-%S}
  job:
    name: log_${now:%Y-%m-%d}_${now:%H-%M-%S}

# datasets
dataset:
  root: '/mnt/nfs/kuroyanagi/clones/Action-Recognition-CNN-LSTM/data'
  name: 'UCF-101'
  frames: 'UCF-101-frames'
  split_file: 'ucfTrainTestlist'
  split_number: 1

# train
train:
  num_epochs: 100
  batch_size: 16
  sequence_length: 40
  image_height: 224
  image_width: 224
  channels: 3
  latent_dim: 512
  lstm_layers: 1
  hidden_dim: 1024
  bidirectional: True
  attention: True
  num_workers: 4
  checkpoint_model: ''
  checkpoint_interval: 5
  checkpoints_dir: 'checkpoints'
  tensorboard_dir: 'logs'
  resume: True

# test or test_on_video
test:
  num_classes: 101
  batch_size: 16
  sequence_length: 40
  image_height: 224
  image_width: 224
  channels: 3
  latent_dim: 512
  lstm_layers: 1
  hidden_dim: 1024
  bidirectional: True
  attention: True
  num_workers: 4
  checkpoint_model: '/mnt/nfs/kuroyanagi/clones/Action-Recognition-CNN-LSTM/experiments/exp01/model_checkpoints/ConvLSTM_45.pth'
  video_name: 'BabyCrawling/v_BabyCrawling_g01_c01.avi'
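
Given the colorlog defaults and the ${now:...} interpolations above, the training entry point presumably loads this file through hydra.main. A minimal sketch (the function body and the commented-out call are assumptions, not the repo's code):

```
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="configs", config_name="default")
def main(cfg: DictConfig) -> None:
    # Print the composed config (defaults plus any command-line overrides)
    print(OmegaConf.to_yaml(cfg))
    # train_model(cfg)  # hypothetical call into the repo's training code


if __name__ == "__main__":
    main()
```

Overrides then compose on the command line, e.g. `python3 train.py train.batch_size=8` (the exact script name is whatever consumes this config).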

configs/experiments/test_exp01.yaml (+5 lines)

# @package _group_

# test
test:
  checkpoint_model: '/mnt/nfs/kuroyanagi/clones/Action-Recognition-CNN-LSTM/experiments/exp01/model_checkpoints/ConvLSTM_45.pth'

configs/experiments/(file name not rendered in this view) (+5 lines)

# @package _group_

# input video
test:
  video_name: 'ApplyEyeMakeup/v_ApplyEyeMakeup_g25_c07.avi'

configs/experiments/train_exp01.yaml (+9 lines)

# @package _group_

# train
train:
  batch_size: 8
  num_epochs: 10
  sequence_length: 20
  checkpoint_interval: 5
  num_workers: 4
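
These experiment files carry only the keys they override; everything else is inherited from configs/default.yaml. Hydra performs the composition itself according to the `# @package _group_` directive, but the effect can be illustrated with a plain OmegaConf merge (a sketch, assuming it is run from the repo root; a direct merge works here because both files nest their keys under `train:`):

```
from omegaconf import OmegaConf

base = OmegaConf.load("configs/default.yaml")
exp = OmegaConf.load("configs/experiments/train_exp01.yaml")
cfg = OmegaConf.merge(base, exp)

print(cfg.train.batch_size)  # 8   (overridden by the experiment)
print(cfg.train.latent_dim)  # 512 (inherited from default.yaml)
```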

data/check_extract_frames.py (+16 lines)

import argparse
import glob
import os

parser = argparse.ArgumentParser()
parser.add_argument("--dataset_frames_path", type=str, default="UCF-101-frames",
                    help="Directory containing the extracted frame folders")
opt = parser.parse_args()

# Each extracted video is a directory of frames: <frames>/<class>/<sequence>/
video_frame_paths = glob.glob(os.path.join(opt.dataset_frames_path, "*", "*"))

# Report sequences for which frame extraction produced no images
for i, video_frame_path in enumerate(video_frame_paths):
    num_frames = len(glob.glob(os.path.join(video_frame_path, "*")))
    if num_frames == 0:
        print(i, video_frame_path)

# Example output: 49 UCF-101-frames/PlayingGuitar/v_PlayingGuitar_g21_c02

data/download_ucf101.sh (+8 lines)

#!/bin/bash

# Download the UCF-101 dataset and the train / test splits
wget --no-check-certificate https://www.crcv.ucf.edu/data/UCF101/UCF101.rar
wget --no-check-certificate https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip

# Extract the archives
unrar x UCF101.rar
unzip UCF101TrainTestSplits-RecognitionTask.zip

data/extract_frames.py (+51 lines)

"""
Helper script for extracting frames from the UCF-101 dataset
"""

import argparse
import datetime
import glob
import os
import time

import av
import tqdm


def extract_frames(video_path):
    """Yield the frames of a video as PIL images."""
    video = av.open(video_path)
    for frame in video.decode(video=0):
        yield frame.to_image()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="UCF-101", help="Path to UCF-101 dataset")
    opt = parser.parse_args()
    print(opt)

    time_left = 0
    prev_time = time.time()
    video_paths = glob.glob(os.path.join(opt.dataset_path, "*", "*.avi"))
    for i, video_path in enumerate(video_paths):
        sequence_type, sequence_name = video_path.split(".avi")[0].split("/")[-2:]
        sequence_path = os.path.join(f"{opt.dataset_path}-frames", sequence_type, sequence_name)

        # Skip videos whose frames have already been extracted
        if os.path.exists(sequence_path):
            continue

        os.makedirs(sequence_path, exist_ok=True)

        # Extract and save frames
        for j, frame in enumerate(
            tqdm.tqdm(
                extract_frames(video_path),
                desc=f"[{i}/{len(video_paths)}] {sequence_name} : ETA {time_left}",
            )
        ):
            frame.save(os.path.join(sequence_path, f"{j}.jpg"))

        # Estimate the time left from the duration of the last video
        videos_left = len(video_paths) - (i + 1)
        time_left = datetime.timedelta(seconds=videos_left * (time.time() - prev_time))
        prev_time = time.time()
