
Commit 79d26b3: Add files via upload (1 parent: 79487b9)

File tree: 6 files changed, +2058 -2 lines


LICENSE (+1,401 lines)

Large diffs are not rendered by default.

README.md (+164 -2)
@@ -1,2 +1,164 @@
- # CrowdCounting-P2PNet
- The official codes for the ICCV2021 Oral presentation "Rethinking Counting and Localization in Crowds: A Purely Point-Based Framework"

# P2PNet (ICCV2021 Oral Presentation)

This repository contains the official PyTorch implementation of **P2PNet**, as described in [Rethinking Counting and Localization in Crowds: A Purely Point-Based Framework](https://arxiv.org/abs/2107.12746).

A brief introduction to P2PNet can be found at [机器之心 (almosthuman)](https://mp.weixin.qq.com/s?__biz=MzA3MzI4MjgzMw==&mid=2650827826&idx=3&sn=edd3d66444130fb34a59d08fab618a9e&chksm=84e5a84cb392215a005a3b3424f20a9d24dc525dcd933960035bf4b6aa740191b5ecb2b7b161&mpshare=1&scene=1&srcid=1004YEOC7HC9daYRYeUio7Xn&sharer_sharetime=1633675738338&sharer_shareid=7d375dccd3b2f9eec5f8b27ee7c04883&version=3.1.16.5505&platform=win#rd).

The code has been tested with PyTorch 1.5.0 and may not run with other versions.
## Visualized demos for P2PNet

<img src="vis/congested1.png" width="1000"/>
<img src="vis/congested2.png" width="1000"/>
<img src="vis/congested3.png" width="1000"/>
## The network

The overall architecture of P2PNet. Built upon a VGG16 backbone, it first introduces an upsampling path to obtain a fine-grained feature map, and then exploits two branches to simultaneously predict a set of point proposals and their confidence scores. A minimal code sketch of this two-branch head is given below the figure.

<img src="vis/net.png" width="1000"/>
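For readers who prefer code, here is a minimal sketch of such a two-branch head in PyTorch. The class name, channel sizes, and the number of anchor points per feature location are illustrative assumptions, not the exact configuration used in this repository; only the output shapes are chosen to match what `engine.py` consumes.

```python
import torch
import torch.nn as nn

# Illustrative two-branch head: one branch regresses point coordinates per
# anchor point, the other predicts a 2-class confidence score per proposal.
# Channel sizes and layer counts are assumptions, not the repo's exact config.
class PointHead(nn.Module):
    def __init__(self, in_channels: int = 256, num_anchor_points: int = 4):
        super().__init__()
        self.regression = nn.Sequential(
            nn.Conv2d(in_channels, 256, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(256, num_anchor_points * 2, 3, padding=1),  # (x, y) per anchor point
        )
        self.classification = nn.Sequential(
            nn.Conv2d(in_channels, 256, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(256, num_anchor_points * 2, 3, padding=1),  # 2 logits (background / person)
        )

    def forward(self, feat: torch.Tensor):
        b = feat.size(0)
        # reshape to (batch, num_proposals, 2), the layout softmax'd in engine.py
        points = self.regression(feat).permute(0, 2, 3, 1).reshape(b, -1, 2)
        logits = self.classification(feat).permute(0, 2, 3, 1).reshape(b, -1, 2)
        return points, logits


if __name__ == "__main__":
    head = PointHead()
    points, logits = head(torch.randn(1, 256, 32, 32))
    print(points.shape, logits.shape)  # both torch.Size([1, 4096, 2])
```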
## Comparison with state-of-the-art methods

P2PNet achieves state-of-the-art performance on several challenging datasets with varying densities.

| Methods | Venue | SHTechPartA <br> MAE/MSE | SHTechPartB <br> MAE/MSE | UCF_CC_50 <br> MAE/MSE | UCF_QNRF <br> MAE/MSE |
|:----:|:----:|:----:|:----:|:----:|:----:|
| CAN | CVPR'19 | 62.3/100.0 | 7.8/12.2 | 212.2/**243.7** | 107.0/183.0 |
| Bayesian+ | ICCV'19 | 62.8/101.8 | 7.7/12.7 | 229.3/308.2 | 88.7/154.8 |
| S-DCNet | ICCV'19 | 58.3/95.0 | 6.7/10.7 | 204.2/301.3 | 104.4/176.1 |
| SANet+SPANet | ICCV'19 | 59.4/92.5 | 6.5/**9.9** | 232.6/311.7 | -/- |
| DUBNet | AAAI'20 | 64.6/106.8 | 7.7/12.5 | 243.8/329.3 | 105.6/180.5 |
| SDANet | AAAI'20 | 63.6/101.8 | 7.8/10.2 | 227.6/316.4 | -/- |
| ADSCNet | CVPR'20 | <u>55.4</u>/97.7 | <u>6.4</u>/11.3 | 198.4/267.3 | **71.3**/**132.5** |
| ASNet | CVPR'20 | 57.78/<u>90.13</u> | -/- | <u>174.84</u>/<u>251.63</u> | 91.59/159.71 |
| AMRNet | ECCV'20 | 61.59/98.36 | 7.02/11.00 | 184.0/265.8 | 86.6/152.2 |
| AMSNet | ECCV'20 | 56.7/93.4 | 6.7/10.2 | 208.4/297.3 | 101.8/163.2 |
| DM-Count | NeurIPS'20 | 59.7/95.7 | 7.4/11.8 | 211.0/291.5 | 85.6/<u>148.3</u> |
| **Ours** | - | **52.74**/**85.06** | **6.25**/**9.9** | **172.72**/256.18 | <u>85.32</u>/154.5 |

Comparison on the [NWPU-Crowd](https://www.crowdbenchmark.com/resultdetail.html?rid=81) dataset.

| Methods | MAE[O] | MSE[O] | MAE[L] | MAE[S] |
|:----:|:----:|:----:|:----:|:----:|
| MCNN | 232.5 | 714.6 | 220.9 | 1171.9 |
| SANet | 190.6 | 491.4 | 153.8 | 716.3 |
| CSRNet | 121.3 | 387.8 | 112.0 | <u>522.7</u> |
| PCC-Net | 112.3 | 457.0 | 111.0 | 777.6 |
| CANNet | 110.0 | 495.3 | 102.3 | 718.3 |
| Bayesian+ | 105.4 | 454.2 | 115.8 | 750.5 |
| S-DCNet | 90.2 | 370.5 | **82.9** | 567.8 |
| DM-Count | <u>88.4</u> | 388.6 | 88.0 | **498.0** |
| **Ours** | **77.44** | **362** | <u>83.28</u> | 553.92 |

The overall performance for both counting and localization.

| nAP$_{\delta}$ | SHTechPartA | SHTechPartB | UCF_CC_50 | UCF_QNRF | NWPU_Crowd |
|:----:|:----:|:----:|:----:|:----:|:----:|
| $\delta=0.05$ | 10.9\% | 23.8\% | 5.0\% | 5.9\% | 12.9\% |
| $\delta=0.25$ | 70.3\% | 84.2\% | 54.5\% | 55.4\% | 71.3\% |
| $\delta=0.50$ | 90.1\% | 94.1\% | 88.1\% | 83.2\% | 89.1\% |
| $\delta=\{{0.05:0.05:0.50}\}$ | 64.4\% | 76.3\% | 54.3\% | 53.1\% | 65.0\% |

Comparison of localization performance in terms of F1-measure on NWPU.

| Method | F1-Measure | Precision | Recall |
|:----:|:----:|:----:|:----:|
| FasterRCNN | 0.068 | 0.958 | 0.035 |
| TinyFaces | 0.567 | 0.529 | 0.611 |
| RAZ | 0.599 | 0.666 | 0.543 |
| Crowd-SDNet | 0.637 | 0.651 | 0.624 |
| PDRNet | 0.653 | 0.675 | 0.633 |
| TopoCount | 0.692 | 0.683 | **0.701** |
| D2CNet | <u>0.700</u> | **0.741** | 0.662 |
| **Ours** | **0.712** | <u>0.729</u> | <u>0.695</u> |
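As a rough illustration of how point-level localization metrics of this kind can be computed, the sketch below greedily matches predicted points to ground-truth points within a pixel threshold and reports precision, recall, and F1. The greedy matching rule and the `sigma` threshold are simplifying assumptions; the exact evaluation protocol is defined by the respective benchmarks (e.g. NWPU-Crowd).

```python
import numpy as np

def point_f1(pred, gt, sigma=8.0):
    """Greedy one-to-one matching of predicted points to GT points within a
    distance threshold `sigma` (pixels). Illustrative only, not the official protocol."""
    pred, gt = np.asarray(pred, dtype=float), np.asarray(gt, dtype=float)
    if len(pred) == 0 or len(gt) == 0:
        return 0.0, 0.0, 0.0
    # pairwise distances; repeatedly take the closest still-unmatched pair
    dist = np.linalg.norm(pred[:, None, :] - gt[None, :, :], axis=-1)
    matched_pred, matched_gt, tp = set(), set(), 0
    pairs = sorted(((i, j) for i in range(len(pred)) for j in range(len(gt))),
                   key=lambda ij: dist[ij])
    for i, j in pairs:
        if dist[i, j] > sigma:
            break
        if i in matched_pred or j in matched_gt:
            continue
        matched_pred.add(i)
        matched_gt.add(j)
        tp += 1
    precision = tp / len(pred)
    recall = tp / len(gt)
    f1 = 2 * precision * recall / (precision + recall + 1e-12)
    return precision, recall, f1
```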
## Installation

* Clone this repo into a directory named P2PNET_ROOT.
* Organize your datasets as required (see the next section).
* Install the Python dependencies. We use Python 3.6.5 and PyTorch 1.5.0.

```
pip install -r requirements.txt
```
## Organize the counting dataset

We use a list file to collect all the images and their ground-truth annotations in a counting dataset. When your dataset is organized as recommended below, the format of this list file is:

```
train/scene01/img01.jpg train/scene01/img01.txt
train/scene01/img02.jpg train/scene01/img02.txt
...
train/scene02/img01.jpg train/scene02/img01.txt
```

### Dataset structures:
```
DATA_ROOT/
|->train/
|   |->scene01/
|   |->scene02/
|   |->...
|->test/
|   |->scene01/
|   |->scene02/
|   |->...
|->train.list
|->test.list
```
DATA_ROOT is your path containing the counting datasets. A helper sketch for generating these list files is given below.
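If your data already follows the layout above, a small helper along the following lines can generate the list files. This script is an illustrative sketch and is not part of the repository; it assumes every `.jpg` has a `.txt` annotation file with the same stem next to it.

```python
# Illustrative helper (not part of this repo): write DATA_ROOT/train.list and
# DATA_ROOT/test.list from the recommended directory layout.
import os

def make_list(data_root, split):
    lines = []
    split_dir = os.path.join(data_root, split)
    for scene in sorted(os.listdir(split_dir)):
        scene_dir = os.path.join(split_dir, scene)
        for name in sorted(os.listdir(scene_dir)):
            if name.endswith('.jpg'):
                img = os.path.join(split, scene, name)   # path relative to DATA_ROOT
                ann = img[:-4] + '.txt'                  # same stem, .txt extension
                lines.append(f'{img} {ann}')
    with open(os.path.join(data_root, f'{split}.list'), 'w') as f:
        f.write('\n'.join(lines) + '\n')

if __name__ == '__main__':
    data_root = '/path/to/DATA_ROOT'  # adjust to your setup
    make_list(data_root, 'train')
    make_list(data_root, 'test')
```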
### Annotations format

For the annotations of each image, we use a single txt file containing one annotation per line. Note that pixel indexing starts at 0. The expected format of each line is:
```
x1 y1
x2 y2
...
```
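A minimal sketch for reading such an annotation file into an N x 2 array of head coordinates (hedged example; the repository's own dataset code handles this internally):

```python
import numpy as np

def load_points(txt_path):
    """Read one 'x y' pair per line into a float array of shape [N, 2].
    Coordinates are 0-indexed pixel positions, as noted above."""
    points = []
    with open(txt_path) as f:
        for line in f:
            parts = line.split()
            if len(parts) >= 2:
                points.append([float(parts[0]), float(parts[1])])
    return np.array(points, dtype=np.float32).reshape(-1, 2)
```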
## Training

The network can be trained with the `train.py` script. For training on SHTechPartA, use

```
CUDA_VISIBLE_DEVICES=0 python train.py --data_root $DATA_ROOT \
    --dataset_file SHHA \
    --epochs 3500 \
    --lr_drop 3500 \
    --output_dir ./logs \
    --checkpoints_dir ./weights \
    --tensorboard_dir ./logs \
    --lr 0.0001 \
    --lr_backbone 0.00001 \
    --batch_size 8 \
    --eval_freq 1 \
    --gpu_id 0
```
By default, a periodic evaluation is conducted on the validation set.
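Curves written to `--tensorboard_dir` can be inspected with TensorBoard, assuming the `tensorboard` package is installed in addition to the `tensorboardX` dependency listed in requirements.txt:

```
tensorboard --logdir ./logs
```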
## Testing

A trained model on SHTechPartA (with an MAE of **51.96**) is provided in ./weights. Run the following command to launch a visualization demo:

```
CUDA_VISIBLE_DEVICES=0 python run_test.py --weight_path ./weights/SHTechA.pth --output_dir ./logs/
```
## Acknowledgements

- Part of the code is borrowed from the [C^3 Framework](https://github.com/gjy3035/C-3-Framework).
- We refer to [DETR](https://github.com/facebookresearch/detr) for the implementation of our matching strategy.
## Citing P2PNet

If you find P2PNet useful in your project, please consider citing us:

```BibTeX
@inproceedings{song2021rethinking,
  title={Rethinking Counting and Localization in Crowds: A Purely Point-Based Framework},
  author={Song, Qingyu and Wang, Changan and Jiang, Zhengkai and Wang, Yabiao and Tai, Ying and Wang, Chengjie and Li, Jilin and Huang, Feiyue and Wu, Yang},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  year={2021}
}
```
## Related works from Tencent Youtu Lab

- [AAAI2021] To Choose or to Fuse? Scale Selection for Crowd Counting. ([paper link](https://ojs.aaai.org/index.php/AAAI/article/view/16360) & [codes](https://github.com/TencentYoutuResearch/CrowdCounting-SASNet))
- [ICCV2021] Uniformity in Heterogeneity: Diving Deep into Count Interval Partition for Crowd Counting. ([paper link](https://arxiv.org/abs/2107.12619) & [codes](https://github.com/TencentYoutuResearch/CrowdCounting-UEPNet))

engine.py (+159)
@@ -0,0 +1,159 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Train and eval functions used in main.py
Mostly copy-paste from DETR (https://github.com/facebookresearch/detr).
"""
import math
import os
import sys
from typing import Iterable

import torch

import util.misc as utils
from util.misc import NestedTensor
import numpy as np
import time
import torchvision.transforms as standard_transforms
import cv2


class DeNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        for t, m, s in zip(tensor, self.mean, self.std):
            t.mul_(s).add_(m)
        return tensor


def vis(samples, targets, pred, vis_dir, des=None):
    '''
    samples -> tensor: [batch, 3, H, W]
    targets -> list of dict: [{'points':[], 'image_id': str}]
    pred -> list: [num_preds, 2]
    '''
    gts = [t['point'].tolist() for t in targets]

    pil_to_tensor = standard_transforms.ToTensor()

    restore_transform = standard_transforms.Compose([
        DeNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        standard_transforms.ToPILImage()
    ])
    # draw one by one
    for idx in range(samples.shape[0]):
        sample = restore_transform(samples[idx])
        sample = pil_to_tensor(sample.convert('RGB')).numpy() * 255
        sample_gt = sample.transpose([1, 2, 0])[:, :, ::-1].astype(np.uint8).copy()
        sample_pred = sample.transpose([1, 2, 0])[:, :, ::-1].astype(np.uint8).copy()

        max_len = np.max(sample_gt.shape)

        size = 2
        # draw gt
        for t in gts[idx]:
            sample_gt = cv2.circle(sample_gt, (int(t[0]), int(t[1])), size, (0, 255, 0), -1)
        # draw predictions
        for p in pred[idx]:
            sample_pred = cv2.circle(sample_pred, (int(p[0]), int(p[1])), size, (0, 0, 255), -1)

        name = targets[idx]['image_id']
        # save the visualized images
        if des is not None:
            cv2.imwrite(os.path.join(vis_dir, '{}_{}_gt_{}_pred_{}_gt.jpg'.format(int(name),
                        des, len(gts[idx]), len(pred[idx]))), sample_gt)
            cv2.imwrite(os.path.join(vis_dir, '{}_{}_gt_{}_pred_{}_pred.jpg'.format(int(name),
                        des, len(gts[idx]), len(pred[idx]))), sample_pred)
        else:
            cv2.imwrite(
                os.path.join(vis_dir, '{}_gt_{}_pred_{}_gt.jpg'.format(int(name), len(gts[idx]), len(pred[idx]))),
                sample_gt)
            cv2.imwrite(
                os.path.join(vis_dir, '{}_gt_{}_pred_{}_pred.jpg'.format(int(name), len(gts[idx]), len(pred[idx]))),
                sample_pred)


# the training routine
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0):
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    # iterate all training samples
    for samples, targets in data_loader:
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        # forward
        outputs = model(samples)
        # calc the losses
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        # reduce all losses
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)
        # backward
        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()
        # update logger
        metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


# the inference routine
@torch.no_grad()
def evaluate_crowd_no_overlap(model, data_loader, device, vis_dir=None):
    model.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    # run inference on all images to calc MAE
    maes = []
    mses = []
    for samples, targets in data_loader:
        samples = samples.to(device)

        outputs = model(samples)
        outputs_scores = torch.nn.functional.softmax(outputs['pred_logits'], -1)[:, :, 1][0]

        outputs_points = outputs['pred_points'][0]

        gt_cnt = targets[0]['point'].shape[0]
        # 0.5 is used by default
        threshold = 0.5

        points = outputs_points[outputs_scores > threshold].detach().cpu().numpy().tolist()
        predict_cnt = int((outputs_scores > threshold).sum())
        # if specified, save the visualized images
        if vis_dir is not None:
            vis(samples, targets, [points], vis_dir)
        # accumulate MAE, MSE
        mae = abs(predict_cnt - gt_cnt)
        mse = (predict_cnt - gt_cnt) * (predict_cnt - gt_cnt)
        maes.append(float(mae))
        mses.append(float(mse))
    # calc MAE, MSE
    mae = np.mean(maes)
    mse = np.sqrt(np.mean(mses))

    return mae, mse
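For orientation, the following hedged sketch shows how these routines might be wired together from a training script. The model, criterion, and data loaders are assumed to be built elsewhere with the repository's own factories (the actual entry point is `train.py`, which handles this); the optimizer choice here is only illustrative.

```python
# Hypothetical driver sketch: only the epoch loop around the engine.py routines
# is shown; model/criterion/data loaders must come from the repo's own builders.
import torch

from engine import train_one_epoch, evaluate_crowd_no_overlap


def run(model, criterion, train_loader, val_loader, device, epochs=3500):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # illustrative optimizer
    for epoch in range(epochs):
        stats = train_one_epoch(model, criterion, train_loader, optimizer, device, epoch)
        # periodic evaluation on the validation set, as described in the README
        mae, mse = evaluate_crowd_no_overlap(model, val_loader, device)
        print(f"epoch {epoch}: loss={stats['loss']:.4f} MAE={mae:.2f} MSE={mse:.2f}")
```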

requirements.txt (+10)
@@ -0,0 +1,10 @@
torch
torchvision
tensorboardX
easydict
pandas
numpy
scipy
matplotlib
Pillow
opencv-python
