Skip to content

Commit 6161f28

Browse files
authored
add unet code (wang-xinyu#380)
* add code * remove txt
1 parent 4ae3e87 commit 6161f28

File tree

6 files changed

+1157
-0
lines changed

6 files changed

+1157
-0
lines changed

unet/CMakeLists.txt

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Build script for the `unet` TensorRT demo executable.
# NOTE: the CUDA and TensorRT locations below are hard-coded absolute paths;
# adjust them to match the local installation before configuring.
cmake_minimum_required(VERSION 2.6)

project(unet)

add_definitions(-std=c++11)

# option() takes (<variable> "<help text>" [value]); without a help string the
# literal OFF would be interpreted as the help text rather than the value.
option(CUDA_USE_STATIC_CUDA_RUNTIME "Use the static version of the CUDA runtime library" OFF)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_BUILD_TYPE Debug)

# Flags for nvcc (only consulted by the FindCUDA cuda_* commands).
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-std=c++11;-g;-G;-gencode;arch=compute_30;code=sm_30)

# cuda directory
include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(/usr/local/cuda-10.2/targets/x86_64-linux/include)
link_directories(/usr/local/cuda-10.2/targets/x86_64-linux/lib)

# tensorrt
include_directories(/home/sycv/workplace/pengyuzhou/TensorRT-7.0.0.11/targets/x86_64-linux-gnu/include)
link_directories(/home/sycv/workplace/pengyuzhou/TensorRT-7.0.0.11/targets/x86_64-linux-gnu/lib)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -Ofast -Wfatal-errors -D_MWAITXINTRIN_H_INCLUDED")

# link library and add exec file
add_executable(unet ${PROJECT_SOURCE_DIR}/unet.cpp)
target_link_libraries(unet nvinfer)
target_link_libraries(unet cudart)

add_definitions(-O2 -pthread)

# opencv library
find_package(OpenCV)
# The variable must be dereferenced; a bare OpenCV_INCLUDE_DIRS would add a
# directory literally named "OpenCV_INCLUDE_DIRS" to the include path.
include_directories(${OpenCV_INCLUDE_DIRS})
target_link_libraries(unet ${OpenCV_LIBS})

unet/README.md

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# tensorrt-unet
2+
This is a TensorRT implementation of UNet, inspired by [tensorrtx](https://github.com/wang-xinyu/tensorrtx) and [pytorch-unet](https://github.com/milesial/Pytorch-UNet).<br>
3+
You can generate a TensorRT engine file with this code, and customize parameters and the network structure to match the network you trained (FP32/FP16 precision, input size, different convolutions, activation functions, ...).<br>
4+
5+
# requirements
6+
7+
TensorRT 7.0 (you need to install tensorrt first)<br>
8+
Cuda 10.2<br>
9+
Python3.7<br>
10+
opencv 4.4<br>
11+
cmake 3.18<br>
12+
# train .pth file and convert .wts
13+
14+
## create env
15+
16+
```
17+
pip install -r requirements.txt
18+
```
19+
20+
## train .pth file
21+
22+
train your dataset by following [pytorch-unet](https://github.com/milesial/Pytorch-UNet) and generate .pth file.<br>
23+
24+
## convert .wts
25+
26+
Run gen_wts from the utils folder, then move the generated .wts file to the project folder.<br>
27+
28+
# generate engine file and infer
29+
30+
## create build folder in project folder
31+
```
32+
mkdir build
33+
```
34+
35+
## make file, generate exec file
36+
```
37+
cd build
38+
cmake ..
39+
make
40+
```
41+
42+
## generate TensorRT engine file and infer image
43+
```
44+
unet -s
45+
```
46+
A unet engine file will then be generated; you can use `unet -d` to run inference on the files in a folder.<br>
47+
```
48+
unet -d ../samples
49+
```
50+
51+
# efficiency
52+
Inference with the TensorRT engine is much faster than with PyTorch:
53+
54+
pytorch | TensorRT FP32 | TensorRT FP16
55+
---- | ----- | ------
56+
816x672 | 816x672 | 816x672
57+
58ms | 43ms (batchsize 8) | 14ms (batchsize 8)
58+
59+
60+
# Further development
61+
62+
1. add INT8 calibrator<br>
63+
2. add custom plugin<br>
64+
etc

unet/common.hpp

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
#ifndef UNET_COMMON_H_
2+
#define UNET_COMMON_H_
3+
4+
#include <fstream>
5+
#include <map>
6+
#include <sstream>
7+
#include <vector>
8+
#include <opencv2/opencv.hpp>
9+
#include <dirent.h>
10+
#include "NvInfer.h"
11+
12+
13+
// Evaluates `status` exactly once; if the result is non-zero, prints it to
// std::cerr and aborts the process.
// The temporary uses a deliberately unusual name so that a caller-side variable
// (e.g. `CHECK(ret)`) can never collide with the macro's own local.
#define CHECK(status)                                                        \
    do {                                                                     \
        auto unet_check_status_ = (status);                                  \
        if (unet_check_status_ != 0) {                                       \
            std::cerr << "Cuda failure: " << unet_check_status_ << std::endl; \
            abort();                                                         \
        }                                                                    \
    } while (0)
23+
24+
using namespace nvinfer1;
25+
26+
27+
28+
29+
30+
// TensorRT weight files have a simple space delimited format:
31+
// [type] [size] <data x size in hex>
32+
std::map<std::string, Weights> loadWeights(const std::string file) {
33+
std::cout << "Loading weights: " << file << std::endl;
34+
std::map<std::string, Weights> weightMap;
35+
36+
// Open weights file
37+
std::ifstream input(file);
38+
assert(input.is_open() && "Unable to load weight file. please check if the .wts file path is right!!!!!!");
39+
40+
// Read number of weight blobs
41+
int32_t count;
42+
input >> count;
43+
assert(count > 0 && "Invalid weight map file.");
44+
45+
while (count--)
46+
{
47+
Weights wt{DataType::kFLOAT, nullptr, 0};
48+
uint32_t size;
49+
50+
// Read name and type of blob
51+
std::string name;
52+
input >> name >> std::dec >> size;
53+
wt.type = DataType::kFLOAT;
54+
55+
// Load blob
56+
uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(val) * size));
57+
for (uint32_t x = 0, y = size; x < y; ++x)
58+
{
59+
input >> std::hex >> val[x];
60+
}
61+
wt.values = val;
62+
63+
wt.count = size;
64+
weightMap[name] = wt;
65+
}
66+
67+
return weightMap;
68+
}
69+
70+
// Adds a per-channel scale layer computed from batch-norm parameters.
//
// Reads `<lname>.weight`, `<lname>.bias`, `<lname>.running_mean` and
// `<lname>.running_var` from weightMap and derives, for every channel:
//   scale = gamma / sqrt(var + eps)
//   shift = beta - mean * gamma / sqrt(var + eps)
//   power = 1
// The three malloc'ed buffers are stored back into weightMap under
// `<lname>.scale`, `<lname>.shift` and `<lname>.power`.
//
// @return the scale layer created via network->addScale (asserted non-null)
IScaleLayer* addBatchNorm2d(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input, std::string lname, float eps) {
    const float* gamma = static_cast<const float*>(weightMap[lname + ".weight"].values);
    const float* beta = static_cast<const float*>(weightMap[lname + ".bias"].values);
    const float* mean = static_cast<const float*>(weightMap[lname + ".running_mean"].values);
    const float* var = static_cast<const float*>(weightMap[lname + ".running_var"].values);
    int channels = weightMap[lname + ".running_var"].count;

    float* scaleVals = reinterpret_cast<float*>(malloc(sizeof(float) * channels));
    float* shiftVals = reinterpret_cast<float*>(malloc(sizeof(float) * channels));
    float* powerVals = reinterpret_cast<float*>(malloc(sizeof(float) * channels));

    // Fill all three per-channel tables in a single pass.
    for (int c = 0; c < channels; ++c) {
        scaleVals[c] = gamma[c] / sqrt(var[c] + eps);
        shiftVals[c] = beta[c] - mean[c] * gamma[c] / sqrt(var[c] + eps);
        powerVals[c] = 1.0;
    }

    Weights scale{DataType::kFLOAT, scaleVals, channels};
    Weights shift{DataType::kFLOAT, shiftVals, channels};
    Weights power{DataType::kFLOAT, powerVals, channels};

    // Record the derived buffers in the map next to the loaded weights.
    weightMap[lname + ".scale"] = scale;
    weightMap[lname + ".shift"] = shift;
    weightMap[lname + ".power"] = power;

    IScaleLayer* bnLayer = network->addScale(input, ScaleMode::kCHANNEL, shift, scale, power);
    assert(bnLayer);
    return bnLayer;
}
102+
103+
104+
// Builds a convolution -> batch-norm -> hard-swish block.
//
// The convolution uses a square ksize x ksize kernel, stride s, padding
// ksize / 2, g groups, the weights stored under `<lname>.conv.weight`, and an
// empty bias. Batch-norm parameters come from `<lname>.bn` with eps = 1e-3.
//
// @return the element-wise product layer that closes the block
ILayer* convBlock(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input, int outch, int ksize, int s, int g, std::string lname) {
    const int pad = ksize / 2;
    Weights noBias{DataType::kFLOAT, nullptr, 0};

    IConvolutionLayer* conv = network->addConvolutionNd(input, outch, DimsHW{ksize, ksize}, weightMap[lname + ".conv.weight"], noBias);
    assert(conv);
    conv->setStrideNd(DimsHW{s, s});
    conv->setPaddingNd(DimsHW{pad, pad});
    conv->setNbGroups(g);

    IScaleLayer* bn = addBatchNorm2d(network, weightMap, *conv->getOutput(0), lname + ".bn", 1e-3);
    ITensor& bnOut = *bn->getOutput(0);

    // hard_swish = x * hard_sigmoid
    auto hardSigmoid = network->addActivation(bnOut, ActivationType::kHARD_SIGMOID);
    assert(hardSigmoid);
    hardSigmoid->setAlpha(1.0 / 6.0);
    hardSigmoid->setBeta(0.5);

    auto product = network->addElementWise(bnOut, *hardSigmoid->getOutput(0), ElementWiseOperation::kPROD);
    assert(product);
    return product;
}
123+
124+
125+
126+
// Appends the name of every entry in directory `p_dir_name` to `file_names`,
// skipping the "." and ".." entries. Only the bare entry names are stored;
// the directory path is not prepended.
//
// @return 0 on success, -1 if the directory cannot be opened
int read_files_in_dir(const char *p_dir_name, std::vector<std::string> &file_names) {
    DIR* dir = opendir(p_dir_name);
    if (!dir) {
        return -1;
    }

    for (struct dirent* entry = readdir(dir); entry != nullptr; entry = readdir(dir)) {
        const std::string entry_name(entry->d_name);
        if (entry_name == "." || entry_name == "..") {
            continue;
        }
        file_names.push_back(entry_name);
    }

    closedir(dir);
    return 0;
}
147+
148+
#endif

unet/gen_wts.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import torch
2+
from torch import nn
3+
import torchvision
4+
import os
5+
import struct
6+
from torchsummary import summary
7+
8+
def main():
    """Load 'ori_unet.pth', run a dummy forward pass and dump the weights to 'unet.wts'.

    Output format: the first line holds the number of state-dict entries; each
    following line is "<key> <element count>" followed by every element as the
    hex encoding of its big-endian 32-bit float, separated by spaces.
    """
    print('cuda device count: ', torch.cuda.device_count())
    net = torch.load('ori_unet.pth')
    net = net.to('cuda:0')
    net = net.eval()
    print('model: ', net)
    #print('state dict: ', net.state_dict().keys())
    tmp = torch.ones(1, 3, 224, 224).to('cuda:0')
    print('input: ', tmp)
    out = net(tmp)

    print('output:', out)

    summary(net, (3, 224, 224))
    #return
    state_dict = net.state_dict()
    # The context manager guarantees the file is flushed and closed, even if
    # serialising one of the tensors raises.
    with open("unet.wts", 'w') as f:
        f.write("{}\n".format(len(state_dict)))
        for k, v in state_dict.items():
            print('key: ', k)
            print('value: ', v.shape)
            vr = v.reshape(-1).cpu().numpy()
            f.write("{} {}".format(k, len(vr)))
            for vv in vr:
                f.write(" ")
                f.write(struct.pack(">f", float(vv)).hex())
            f.write("\n")


if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)