
Commit 01f8c64

Add streaming video support using VideoPipe (#17, #19)
1 parent 2cacb7d commit 01f8c64

File tree

9 files changed: +579 −0 lines changed


assets/videopipe.jpg

485 KB

demo/VideoPipe/CMakeLists.txt

+151
@@ -0,0 +1,151 @@
# this is the build file for project
# it is autogenerated by the xmake build system.
# do not edit by hand.

# project
cmake_minimum_required(VERSION 3.15.0)
cmake_policy(SET CMP0091 NEW)
project(PipeDemo LANGUAGES CXX CUDA)

# target
add_executable(PipeDemo "")
set_target_properties(PipeDemo PROPERTIES OUTPUT_NAME "PipeDemo")
set_target_properties(PipeDemo PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/build/linux/x86_64/release")
target_include_directories(PipeDemo PRIVATE
    /usr/local/tensorrt/include
    /home/laugh/Projects/TensorRT-YOLO/include
    /home/laugh/Projects/VideoPipe
    /usr/local/cuda/include
)
target_include_directories(PipeDemo SYSTEM PRIVATE
    /usr/local/include/opencv4
    /usr/include/libdrm
    /usr/include/gstreamer-1.0
    /usr/include/x86_64-linux-gnu
    /usr/include/glib-2.0
    /usr/lib/x86_64-linux-gnu/glib-2.0/include
)
target_compile_options(PipeDemo PRIVATE
    $<$<COMPILE_LANGUAGE:C>:-m64>
    $<$<COMPILE_LANGUAGE:CXX>:-m64>
    $<$<COMPILE_LANGUAGE:C>:-DNDEBUG>
    $<$<COMPILE_LANGUAGE:CXX>:-DNDEBUG>
    $<$<COMPILE_LANGUAGE:C>:-pthread>
    $<$<COMPILE_LANGUAGE:CXX>:-pthread>
    $<$<COMPILE_LANGUAGE:CUDA>:-allow-unsupported-compiler>
    $<$<COMPILE_LANGUAGE:CUDA>:-m64>
    $<$<COMPILE_LANGUAGE:CUDA>:-rdc=true>
    $<$<COMPILE_LANGUAGE:CUDA>:-gencode arch=compute_75,code=sm_75>
)
set_target_properties(PipeDemo PROPERTIES CXX_EXTENSIONS OFF)
target_compile_features(PipeDemo PRIVATE cxx_std_17)
if(MSVC)
    target_compile_options(PipeDemo PRIVATE $<$<CONFIG:Release>:-Ox -fp:fast>)
else()
    target_compile_options(PipeDemo PRIVATE -O3)
endif()
if(MSVC)
else()
    target_compile_options(PipeDemo PRIVATE -fvisibility=hidden)
endif()
if(MSVC)
    set_property(TARGET PipeDemo PROPERTY
        MSVC_RUNTIME_LIBRARY "MultiThreaded$<$<CONFIG:Debug>:Debug>")
endif()
target_link_libraries(PipeDemo PRIVATE
    nvinfer
    nvinfer_plugin
    nvonnxparser
    deploy
    video_pipe
    tinyexpr
    opencv_gapi
    opencv_stitching
    opencv_aruco
    opencv_bgsegm
    opencv_bioinspired
    opencv_ccalib
    opencv_cudabgsegm
    opencv_cudafeatures2d
    opencv_cudaobjdetect
    opencv_cudastereo
    opencv_dnn_objdetect
    opencv_dnn_superres
    opencv_dpm
    opencv_face
    opencv_freetype
    opencv_fuzzy
    opencv_hdf
    opencv_hfs
    opencv_img_hash
    opencv_intensity_transform
    opencv_line_descriptor
    opencv_mcc
    opencv_quality
    opencv_rapid
    opencv_reg
    opencv_rgbd
    opencv_saliency
    opencv_signal
    opencv_stereo
    opencv_structured_light
    opencv_phase_unwrapping
    opencv_superres
    opencv_surface_matching
    opencv_tracking
    opencv_highgui
    opencv_datasets
    opencv_text
    opencv_plot
    opencv_videostab
    opencv_cudaoptflow
    opencv_optflow
    opencv_cudalegacy
    opencv_videoio
    opencv_cudawarping
    opencv_wechat_qrcode
    opencv_xfeatures2d
    opencv_shape
    opencv_ml
    opencv_ximgproc
    opencv_video
    opencv_xobjdetect
    opencv_objdetect
    opencv_calib3d
    opencv_imgcodecs
    opencv_features2d
    opencv_dnn
    opencv_flann
    opencv_xphoto
    opencv_photo
    opencv_cudaimgproc
    opencv_cudafilters
    opencv_imgproc
    opencv_cudaarithm
    opencv_core
    opencv_cudev
    drm
    gstreamer-1.0
    gobject-2.0
    glib-2.0
    cudadevrt
    cudart_static
    rt
    pthread
    dl
)
target_link_directories(PipeDemo PRIVATE
    /usr/local/tensorrt/lib
    /home/laugh/Projects/TensorRT-YOLO/lib
    /home/laugh/Projects/VideoPipe/build/libs
    /usr/local/cuda/lib64
    /usr/local/lib
)
target_link_options(PipeDemo PRIVATE
    -m64
)
target_sources(PipeDemo PRIVATE
    src/main.cpp
    src/vp_trtyolo_detector.cpp
)

demo/VideoPipe/README.en.md

+49
@@ -0,0 +1,49 @@
English | [简体中文](README.md)

# Video Analysis Example

This example uses the YOLOv8s model to demonstrate how to integrate the TensorRT-YOLO Deploy module into [VideoPipe](https://github.com/sherlockchou86/VideoPipe) for video analysis.

## Model Export

First, download the YOLOv8s model from [YOLOv8s](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8s.pt) and save it to the `workspace` folder.

Next, use the following command to export the model to ONNX format with the [EfficientNMS](https://github.com/NVIDIA/TensorRT/tree/main/plugin/efficientNMSPlugin) plugin:

```bash
cd workspace
trtyolo export -w yolov8s.pt -v yolov8 -o models -b 2
```

After executing the command above, a file named `yolov8s.onnx` will be generated in the `models` folder. Then, convert the ONNX file to a TensorRT engine using the `trtexec` tool:

```bash
cd workspace
trtexec --onnx=yolov8s.onnx --saveEngine=yolov8s.engine --fp16
```

## Project Execution

Before performing inference, make sure VideoPipe and TensorRT-YOLO have been compiled.

Next, use xmake to compile the project into an executable:

```bash
xmake f -P . --tensorrt=/path/to/your/TensorRT --deploy=/path/to/your/TensorRT-YOLO --videopipe=/path/to/your/VideoPipe

xmake -P . -r
```

After a successful build, you can run the generated executable directly or use the `xmake run` command to perform inference:

```bash
xmake run -P . PipeDemo
```

<div align="center">
<p>
<img width="100%" src="../../assets/videopipe.jpg">
</p>
</div>

The above demonstrates how to run model inference with this example.

demo/VideoPipe/README.md

+49
@@ -0,0 +1,49 @@
[English](README.en.md) | Simplified Chinese

# Video Analysis Example

This example uses the YOLOv8s model to demonstrate how to integrate the TensorRT-YOLO Deploy module into [VideoPipe](https://github.com/sherlockchou86/VideoPipe) for video analysis.

## Model Export

First, download the YOLOv8s model from [YOLOv8s](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8s.pt) and save it to the `workspace` folder.

Then, use the following command to export the model to ONNX format with the [EfficientNMS](https://github.com/NVIDIA/TensorRT/tree/main/plugin/efficientNMSPlugin) plugin:

```bash
cd workspace
trtyolo export -w yolov8s.pt -v yolov8 -o models -b 2
```

After running the command above, a file named `yolov8s.onnx` will be generated in the `models` folder. Next, use the `trtexec` tool to convert the ONNX file into a TensorRT engine:

```bash
cd workspace
trtexec --onnx=yolov8s.onnx --saveEngine=yolov8s.engine --fp16
```

## Running the Project

Before running inference, make sure VideoPipe and TensorRT-YOLO have been compiled.

Next, use xmake to compile the project into an executable:

```bash
xmake f -P . --tensorrt=/path/to/your/TensorRT --deploy=/path/to/your/TensorRT-YOLO --videopipe=/path/to/your/VideoPipe

xmake -P . -r
```

After a successful build, you can run the generated executable directly or use the `xmake run` command to perform inference:

```bash
xmake run -P . PipeDemo
```

<div align="center">
<p>
<img width="100%" src="../../assets/videopipe.jpg">
</p>
</div>

The above is an example of how to run model inference.

demo/VideoPipe/src/main.cpp

+58
@@ -0,0 +1,58 @@
#include "vp_trtyolo_detector.h"
#include "nodes/vp_split_node.h"
#include "nodes/osd/vp_osd_node.h"
#include "nodes/vp_file_src_node.h"
#include "nodes/vp_screen_des_node.h"
#include "nodes/track/vp_sort_track_node.h"
#include "utils/analysis_board/vp_analysis_board.h"

int main() {
    // Disable logging code location and thread ID
    VP_SET_LOG_INCLUDE_CODE_LOCATION(false);
    VP_SET_LOG_INCLUDE_THREAD_ID(false);
    VP_LOGGER_INIT();

    // Video sources
    auto file_src_0 = std::make_shared<vp_nodes::vp_file_src_node>("file_src_0", 0, "demo0.mp4");
    auto file_src_1 = std::make_shared<vp_nodes::vp_file_src_node>("file_src_1", 1, "demo1.mp4");

    // Inference node (TensorRT-YOLO detector)
    auto detector = std::make_shared<vp_nodes::vp_trtyolo_detector>("yolo_detector", "yolov8s.engine", "labels.txt", true, 2);

    // Tracking node (SORT tracker)
    auto track = std::make_shared<vp_nodes::vp_sort_track_node>("track");

    // OSD (On-Screen Display) node
    auto osd = std::make_shared<vp_nodes::vp_osd_node>("osd");

    // Channel splitting node
    auto split = std::make_shared<vp_nodes::vp_split_node>("split_by_channel", true);

    // Local display nodes
    auto screen_des_0 = std::make_shared<vp_nodes::vp_screen_des_node>("screen_des_0", 0);
    auto screen_des_1 = std::make_shared<vp_nodes::vp_screen_des_node>("screen_des_1", 1);

    // Construct the pipeline
    detector->attach_to({file_src_0, file_src_1});
    track->attach_to({detector});
    osd->attach_to({track});
    split->attach_to({osd});

    // Split by vp_split_node for per-channel display
    screen_des_0->attach_to({split});
    screen_des_1->attach_to({split});

    // Start video sources
    file_src_0->start();
    file_src_1->start();

    // Debugging: display the analysis board
    vp_utils::vp_analysis_board board({file_src_0, file_src_1});
    board.display(1, false); // Refresh every 1 second, non-blocking

    // Wait for user input to stop and detach nodes recursively
    std::string wait;
    std::getline(std::cin, wait);
    file_src_0->detach_recursively();
    file_src_1->detach_recursively();
}
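The pipeline above hard-codes two file sources because the engine was exported with a batch size of 2 (`trtyolo export ... -b 2`) and the detector node is created with `batch = 2`. As a rough sketch (an assumption for illustration, not part of this commit), a single-stream setup would use one source, a batch of 1, and an engine exported with `-b 1`; it is also assumed the split node can be dropped when there is only one channel:

```cpp
// Hypothetical single-stream variant of the pipeline above
// (assumes yolov8s.engine was built for batch size 1).
auto file_src_0 = std::make_shared<vp_nodes::vp_file_src_node>("file_src_0", 0, "demo0.mp4");
auto detector   = std::make_shared<vp_nodes::vp_trtyolo_detector>("yolo_detector", "yolov8s.engine", "labels.txt", true, 1);
auto track      = std::make_shared<vp_nodes::vp_sort_track_node>("track");
auto osd        = std::make_shared<vp_nodes::vp_osd_node>("osd");
auto screen_des = std::make_shared<vp_nodes::vp_screen_des_node>("screen_des_0", 0);

detector->attach_to({file_src_0});
track->attach_to({detector});
osd->attach_to({track});
screen_des->attach_to({osd});   // single channel, so no vp_split_node is assumed to be needed

file_src_0->start();
```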
demo/VideoPipe/src/vp_trtyolo_detector.cpp

+83
@@ -0,0 +1,83 @@
#include "vp_trtyolo_detector.h"

namespace vp_nodes {

vp_trtyolo_detector::vp_trtyolo_detector(std::string node_name, std::string model_path, std::string labels_path, bool cudagraph, int batch, int device_id)
    : vp_primary_infer_node(node_name, "", "", labels_path, 0, 0, batch), use_cudagraph(cudagraph) {

    // Initialize detector based on CUDA graph usage
    if (use_cudagraph) {
        detector = std::make_shared<deploy::DeployCGDet>(model_path, device_id);
        if (detector->batch != batch) {
            throw std::runtime_error("Batch size mismatch: expected " + std::to_string(batch) + ", but got " + std::to_string(detector->batch));
        }
    } else {
        detector = std::make_shared<deploy::DeployDet>(model_path, device_id);
        if (detector->batch < batch) {
            throw std::runtime_error("Batch size too large: expected <= " + std::to_string(detector->batch) + ", but got " + std::to_string(batch));
        }
    }

    this->initialized(); // Mark node as initialized
}

vp_trtyolo_detector::~vp_trtyolo_detector() {
    // Destructor: clean up any resources
    deinitialized(); // Mark node as deinitialized
}

void vp_trtyolo_detector::run_infer_combinations(const std::vector<std::shared_ptr<vp_objects::vp_frame_meta>>& frame_meta_with_batch) {
    if (use_cudagraph)
        assert(frame_meta_with_batch.size() == detector->batch); // Assert batch size consistency when using CUDA graph

    std::vector<cv::Mat> mats_to_infer;
    std::vector<deploy::Image> images_to_infer;

    auto start_time = std::chrono::system_clock::now(); // Start time for performance measurement

    // Prepare data for inference (same as base class)
    vp_primary_infer_node::prepare(frame_meta_with_batch, mats_to_infer);
    std::transform(mats_to_infer.begin(), mats_to_infer.end(), std::back_inserter(images_to_infer), [](cv::Mat& mat) {
        return deploy::Image(mat.data, mat.cols, mat.rows);
    });

    auto prepare_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now() - start_time);

    start_time = std::chrono::system_clock::now();

    // Perform inference on prepared images
    std::vector<deploy::DetectionResult> detection_results = detector->predict(images_to_infer);

    // Process detection results and update frame metadata
    for (int i = 0; i < detection_results.size(); i++) {
        auto& frame_meta = frame_meta_with_batch[i];
        auto& detection_result = detection_results[i];

        for (int j = 0; j < detection_result.num; j++) {
            int x = static_cast<int>(detection_result.boxes[j].left);
            int y = static_cast<int>(detection_result.boxes[j].top);
            int width = static_cast<int>(detection_result.boxes[j].right - detection_result.boxes[j].left);
            int height = static_cast<int>(detection_result.boxes[j].bottom - detection_result.boxes[j].top);
            auto label = labels.size() == 0 ? "" : labels[detection_result.classes[j]];

            // Create target and write it back into the frame meta
            auto target = std::make_shared<vp_objects::vp_frame_target>(
                x, y, width, height, detection_result.classes[j], detection_result.scores[j],
                frame_meta->frame_index, frame_meta->channel_index, label);

            frame_meta->targets.push_back(target);
        }
    }

    auto infer_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now() - start_time);

    // Preprocess and postprocess times cannot be measured separately here, so report 0 for both.
    vp_infer_node::infer_combinations_time_cost(mats_to_infer.size(), prepare_time.count(), 0, infer_time.count(), 0);
}

void vp_trtyolo_detector::postprocess(const std::vector<cv::Mat>& raw_outputs, const std::vector<std::shared_ptr<vp_objects::vp_frame_meta>>& frame_meta_with_batch) {
    // Placeholder for postprocessing logic if needed in future enhancements
    // Currently not implemented in this class
}

} // namespace vp_nodes
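For readers unfamiliar with the Deploy module being wrapped here, the calls used above (`deploy::DeployDet`, `deploy::Image`, `predict`, `deploy::DetectionResult`) can also be exercised on a single image outside VideoPipe. Below is a minimal sketch using only that API surface; the include path is an assumption, so adjust it to your TensorRT-YOLO installation:

```cpp
// Minimal standalone sketch of the Deploy API surface used by vp_trtyolo_detector.
// Only calls visible in the node above are used; the header path is a guess.
#include <opencv2/opencv.hpp>
#include <vector>
#include "deploy/vision/detection.hpp"  // hypothetical include; match your TensorRT-YOLO headers

int main() {
    // Same constructor arguments as the non-CUDA-graph branch above: engine path and device id.
    deploy::DeployDet detector("yolov8s.engine", /*device_id=*/0);

    cv::Mat mat = cv::imread("frame.jpg");
    std::vector<deploy::Image> batch{deploy::Image(mat.data, mat.cols, mat.rows)};

    std::vector<deploy::DetectionResult> results = detector.predict(batch);

    // Draw the detections for the single image in the batch.
    const auto& res = results[0];
    for (int j = 0; j < res.num; ++j) {
        const auto& box = res.boxes[j];
        cv::rectangle(mat,
                      cv::Point(static_cast<int>(box.left), static_cast<int>(box.top)),
                      cv::Point(static_cast<int>(box.right), static_cast<int>(box.bottom)),
                      cv::Scalar(0, 255, 0), 2);
    }
    cv::imwrite("result.jpg", mat);
    return 0;
}
```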

0 commit comments
