Commit 93e7287

Add EfficientNet (wang-xinyu#590)
* create psenet with weight from tensorflow
* delete some useless code
* replace tabs with 4 spaces
* fix network bug, rewrite post-processing PSE algorithm
* update readme
* update readme
* add RepVGG
* fix typo
* add hrnetseg w18 w32 w48
* add hrnetseg with ocr w18 w32 w48
* merge hrnet and small, add hrnet_ocr
* fix warning
* change project name
* add efficientnet
1 parent 668d89b commit 93e7287

6 files changed, +1122 −0 lines changed

Diff for: efficientnet/CMakeLists.txt

+27
@@ -0,0 +1,27 @@
cmake_minimum_required(VERSION 2.6)

project(efficientnet)

add_definitions(-std=c++11)

option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_BUILD_TYPE Debug)

find_package(CUDA REQUIRED)

include_directories(${PROJECT_SOURCE_DIR}/include)
# include and link dirs of cuda and tensorrt, you need to adapt them if yours are different
# cuda
include_directories(/usr/local/cuda/include)
link_directories(/usr/local/cuda/lib64)
# tensorrt
include_directories(/usr/include/x86_64-linux-gnu/)
link_directories(/usr/lib/x86_64-linux-gnu/)

add_executable(efficientnet ${PROJECT_SOURCE_DIR}/efficientnet.cpp)
target_link_libraries(efficientnet nvinfer)
target_link_libraries(efficientnet cudart)

add_definitions(-O2 -pthread)

Diff for: efficientnet/README.md

+45
@@ -0,0 +1,45 @@
# EfficientNet

A TensorRT implementation of EfficientNet.
For the PyTorch implementation, you can refer to [EfficientNet-PyTorch](https://github.com/lukemelas/EfficientNet-PyTorch).

## How to run

1. install `efficientnet_pytorch`
```
pip install efficientnet_pytorch
```

2. generate the `.wts` file
```
python gen_wts.py
```

3. build

```
mkdir build
cd build
cmake ..
make
```
4. serialize the model to an engine file
```
./efficientnet -s [.wts] [.engine] [b0 b1 b2 b3 ... b7]  // serialize model to engine file
```
for example
```
./efficientnet -s ../efficientnet-b3.wts efficientnet-b3.engine b3
```
5. deserialize and run inference
```
./efficientnet -d [.engine] [b0 b1 b2 b3 ... b7]  // deserialize engine file and run inference
```
for example
```
./efficientnet -d efficientnet-b3.engine b3
```
6. check that the output matches the PyTorch side (a minimal comparison sketch follows this file)


For more models, please refer to [tensorrtx](https://github.com/wang-xinyu/tensorrtx)
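
For step 6, here is a minimal sketch of the PyTorch-side check (assuming `efficientnet_pytorch` is installed as in step 1; the constant 0.1 input mirrors the dummy input that `efficientnet.cpp` feeds the engine, and the C++ side prints the first 20 values of `prob`):

```python
# Hypothetical comparison helper -- not part of this commit.
import torch
from efficientnet_pytorch import EfficientNet

model = EfficientNet.from_pretrained('efficientnet-b3')
model.eval()

# Same dummy input as efficientnet.cpp: every element 0.1; b3 expects
# 300x300 (see global_params_map in efficientnet.cpp).
x = torch.full((1, 3, 300, 300), 0.1)
with torch.no_grad():
    logits = model(x)

# Print the first 20 logits, matching the C++ output loop.
print(logits[0, :20])
```

The values should agree with the TensorRT output up to small floating-point differences (larger if the engine was built with `USE_FP16`).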

Diff for: efficientnet/efficientnet.cpp

+280
@@ -0,0 +1,280 @@
#include "NvInfer.h"
#include "cuda_runtime_api.h"
#include "logging.h"
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <vector>
#include <chrono>
#include "utils.hpp"

#define USE_FP32 //USE_FP16
#define INPUT_NAME "data"
#define OUTPUT_NAME "prob"
#define MAX_BATCH_SIZE 8

using namespace nvinfer1;
static Logger gLogger;

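// Each BlockArgs row: num_repeat, kernel_size, stride, expand_ratio,
//                     input_filters, output_filters, se_ratio, id_skip
// (field order assumed from the BlockArgs struct in utils.hpp, mirroring
// EfficientNet-PyTorch's block arguments)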
static std::vector<BlockArgs>
    block_args_list = {
        BlockArgs{1, 3, 1, 1, 32, 16, 0.25, true},
        BlockArgs{2, 3, 2, 6, 16, 24, 0.25, true},
        BlockArgs{2, 5, 2, 6, 24, 40, 0.25, true},
        BlockArgs{3, 3, 2, 6, 40, 80, 0.25, true},
        BlockArgs{3, 5, 1, 6, 80, 112, 0.25, true},
        BlockArgs{4, 5, 2, 6, 112, 192, 0.25, true},
        BlockArgs{1, 3, 1, 6, 192, 320, 0.25, true}};

static std::map<std::string, GlobalParams>
    global_params_map = {
        // input_h, input_w, num_classes, batch_norm_epsilon,
        // width_coefficient, depth_coefficient, depth_divisor, min_depth
        {"b0", GlobalParams{224, 224, 1000, 0.001, 1.0, 1.0, 8, -1}},
        {"b1", GlobalParams{240, 240, 1000, 0.001, 1.0, 1.1, 8, -1}},
        {"b2", GlobalParams{260, 260, 1000, 0.001, 1.1, 1.2, 8, -1}},
        {"b3", GlobalParams{300, 300, 1000, 0.001, 1.2, 1.4, 8, -1}},
        {"b4", GlobalParams{380, 380, 1000, 0.001, 1.4, 1.8, 8, -1}},
        {"b5", GlobalParams{456, 456, 1000, 0.001, 1.6, 2.2, 8, -1}},
        {"b6", GlobalParams{528, 528, 1000, 0.001, 1.8, 2.6, 8, -1}},
        {"b7", GlobalParams{600, 600, 1000, 0.001, 2.0, 3.1, 8, -1}},
        {"b8", GlobalParams{672, 672, 1000, 0.001, 2.2, 3.6, 8, -1}},
        {"l2", GlobalParams{800, 800, 1000, 0.001, 4.3, 5.3, 8, -1}},
};

ICudaEngine *createEngine(unsigned int maxBatchSize, IBuilder *builder, IBuilderConfig *config, DataType dt, std::string path_wts, std::vector<BlockArgs> block_args_list, GlobalParams global_params)
{
    float bn_eps = global_params.batch_norm_epsilon;
    DimsHW image_size = DimsHW{global_params.input_h, global_params.input_w};

    std::map<std::string, Weights> weightMap = loadWeights(path_wts);
    Weights emptywts{DataType::kFLOAT, nullptr, 0};
    INetworkDefinition *network = builder->createNetworkV2(0U);
    ITensor *data = network->addInput(INPUT_NAME, dt, Dims3{3, global_params.input_h, global_params.input_w});
    assert(data);

    // Stem: 3x3 stride-2 conv -> BN -> swish
    int out_channels = roundFilters(32, global_params);
    auto conv_stem = addSamePaddingConv2d(network, weightMap, *data, out_channels, 3, 2, 1, 1, image_size, "_conv_stem");
    auto bn0 = addBatchNorm2d(network, weightMap, *conv_stem->getOutput(0), "_bn0", bn_eps);
    auto swish0 = addSwish(network, *bn0->getOutput(0));
    ITensor *x = swish0->getOutput(0);
    image_size = calculateOutputImageSize(image_size, 2);
    int block_id = 0;
    for (int i = 0; i < block_args_list.size(); i++)
    {
        BlockArgs block_args = block_args_list[i];

        block_args.input_filters = roundFilters(block_args.input_filters, global_params);
        block_args.output_filters = roundFilters(block_args.output_filters, global_params);
        block_args.num_repeat = roundRepeats(block_args.num_repeat, global_params);
        x = MBConvBlock(network, weightMap, *x, "_blocks." + std::to_string(block_id), block_args, global_params, image_size);

        assert(x);
        block_id++;
        image_size = calculateOutputImageSize(image_size, block_args.stride);
        // Repeated blocks within a stage keep stride 1 and matching channel counts
        if (block_args.num_repeat > 1)
        {
            block_args.input_filters = block_args.output_filters;
            block_args.stride = 1;
        }
        for (int r = 0; r < block_args.num_repeat - 1; r++)
        {
            x = MBConvBlock(network, weightMap, *x, "_blocks." + std::to_string(block_id), block_args, global_params, image_size);
            block_id++;
        }
    }
    // Head: 1x1 conv -> BN -> swish -> global average pool -> fully connected
    out_channels = roundFilters(1280, global_params);
    auto conv_head = addSamePaddingConv2d(network, weightMap, *x, out_channels, 1, 1, 1, 1, image_size, "_conv_head", false);
    auto bn1 = addBatchNorm2d(network, weightMap, *conv_head->getOutput(0), "_bn1", bn_eps);
    auto swish1 = addSwish(network, *bn1->getOutput(0));
    auto avg_pool = network->addPoolingNd(*swish1->getOutput(0), PoolingType::kAVERAGE, image_size);

    IFullyConnectedLayer *final = network->addFullyConnected(*avg_pool->getOutput(0), global_params.num_classes, weightMap["_fc.weight"], weightMap["_fc.bias"]);
    assert(final);

    final->getOutput(0)->setName(OUTPUT_NAME);
    network->markOutput(*final->getOutput(0));

    // Build engine
    builder->setMaxBatchSize(maxBatchSize);
    config->setMaxWorkspaceSize(1 << 20);
#ifdef USE_FP16
    config->setFlag(BuilderFlag::kFP16);
#endif
    std::cout << "build engine ..." << std::endl;

    ICudaEngine *engine = builder->buildEngineWithConfig(*network, *config);
    assert(engine != nullptr);

    std::cout << "build finished" << std::endl;
    // Don't need the network any more
    network->destroy();
    // Release host memory
    for (auto &mem : weightMap)
    {
        free((void *)(mem.second.values));
    }

    return engine;
}

void APIToModel(unsigned int maxBatchSize, IHostMemory **modelStream, std::string wtsPath, std::vector<BlockArgs> block_args_list, GlobalParams global_params)
{
    // Create builder
    IBuilder *builder = createInferBuilder(gLogger);
    IBuilderConfig *config = builder->createBuilderConfig();

    // Create model to populate the network, then set the outputs and create an engine
    ICudaEngine *engine = createEngine(maxBatchSize, builder, config, DataType::kFLOAT, wtsPath, block_args_list, global_params);
    assert(engine != nullptr);

    // Serialize the engine
    (*modelStream) = engine->serialize();

    // Close everything down
    engine->destroy();
    builder->destroy();
    config->destroy();
}

void doInference(IExecutionContext &context, float *input, float *output, int batchSize, GlobalParams global_params)
{
    const ICudaEngine &engine = context.getEngine();

    // Pointers to input and output device buffers to pass to engine.
    // Engine requires exactly IEngine::getNbBindings() number of buffers.
    assert(engine.getNbBindings() == 2);
    void *buffers[2];

    // In order to bind the buffers, we need to know the names of the input and output tensors.
    // Note that indices are guaranteed to be less than IEngine::getNbBindings()
    const int inputIndex = engine.getBindingIndex(INPUT_NAME);
    const int outputIndex = engine.getBindingIndex(OUTPUT_NAME);

    // Create GPU buffers on device
    CHECK(cudaMalloc(&buffers[inputIndex], batchSize * 3 * global_params.input_h * global_params.input_w * sizeof(float)));
    CHECK(cudaMalloc(&buffers[outputIndex], batchSize * global_params.num_classes * sizeof(float)));

    // Create stream
    cudaStream_t stream;
    CHECK(cudaStreamCreate(&stream));

    // DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host
    CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * 3 * global_params.input_h * global_params.input_w * sizeof(float), cudaMemcpyHostToDevice, stream));
    context.enqueue(batchSize, buffers, stream, nullptr);
    CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * global_params.num_classes * sizeof(float), cudaMemcpyDeviceToHost, stream));
    cudaStreamSynchronize(stream);

    // Release stream and buffers
    cudaStreamDestroy(stream);
    CHECK(cudaFree(buffers[inputIndex]));
    CHECK(cudaFree(buffers[outputIndex]));
}

bool parse_args(int argc, char **argv, std::string &wts, std::string &engine, std::string &backbone)
{
    if (argc < 2)
        return false; // guard against a missing mode flag before touching argv[1]
    if (std::string(argv[1]) == "-s" && argc == 5)
    {
        wts = std::string(argv[2]);
        engine = std::string(argv[3]);
        backbone = std::string(argv[4]);
    }
    else if (std::string(argv[1]) == "-d" && argc == 4)
    {
        engine = std::string(argv[2]);
        backbone = std::string(argv[3]);
    }
    else
    {
        return false;
    }
    return true;
}

int main(int argc, char **argv)
{
    std::string wtsPath = "";
    std::string engine_name = "";
    std::string backbone = "";
    if (!parse_args(argc, argv, wtsPath, engine_name, backbone))
    {
        std::cerr << "arguments not right!" << std::endl;
        std::cerr << "./efficientnet -s [.wts] [.engine] [b0 b1 b2 b3 ... b7]  // serialize model to engine file" << std::endl;
        std::cerr << "./efficientnet -d [.engine] [b0 b1 b2 b3 ... b7]  // deserialize engine file and run inference" << std::endl;
        return -1;
    }
    GlobalParams global_params = global_params_map[backbone];
    // create a model using the API directly and serialize it to a stream
    if (!wtsPath.empty())
    {
        IHostMemory *modelStream{nullptr};
        APIToModel(MAX_BATCH_SIZE, &modelStream, wtsPath, block_args_list, global_params);
        assert(modelStream != nullptr);

        std::ofstream p(engine_name, std::ios::binary);
        if (!p)
        {
            std::cerr << "could not open plan output file" << std::endl;
            return -1;
        }
        p.write(reinterpret_cast<const char *>(modelStream->data()), modelStream->size());
        modelStream->destroy();
        return 0; // was `return 1`; successful serialization should exit with 0
    }

    char *trtModelStream{nullptr};
    size_t size{0};

    std::ifstream file(engine_name, std::ios::binary);
    if (file.good())
    {
        file.seekg(0, file.end);
        size = file.tellg();
        file.seekg(0, file.beg);
        trtModelStream = new char[size];
        assert(trtModelStream);
        file.read(trtModelStream, size);
        file.close();
    }
    else
    {
        std::cerr << "could not open plan file" << std::endl;
        return -1;
    }

    // dummy input
    float *data = new float[3 * global_params.input_h * global_params.input_w];
    for (int i = 0; i < 3 * global_params.input_h * global_params.input_w; i++)
        data[i] = 0.1;

    IRuntime *runtime = createInferRuntime(gLogger);
    assert(runtime != nullptr);
    ICudaEngine *engine = runtime->deserializeCudaEngine(trtModelStream, size, nullptr);
    assert(engine != nullptr);
    IExecutionContext *context = engine->createExecutionContext();
    assert(context != nullptr);
    delete[] trtModelStream;

    // Run inference
    float *prob = new float[global_params.num_classes];
    for (int i = 0; i < 100; i++)
    {
        auto start = std::chrono::system_clock::now();
        doInference(*context, data, prob, 1, global_params);
        auto end = std::chrono::system_clock::now();
        std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;
    }
    for (unsigned int i = 0; i < 20; i++)
    {
        std::cout << prob[i] << ", ";
    }
    std::cout << std::endl;
    // Destroy the engine
    context->destroy();
    engine->destroy();
    runtime->destroy();
    delete[] data; // arrays allocated with new[] require delete[]
    delete[] prob;

    return 0;
}
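
Note that `main` feeds a constant 0.1 dummy input with no image preprocessing. For real images, the C++ side would need to replicate whatever the PyTorch reference does; here is a hedged sketch, assuming the standard ImageNet normalization used in the EfficientNet-PyTorch examples (verify against your own PyTorch pipeline):

```python
# Hypothetical preprocessing sketch -- not part of this commit.
import torch
from PIL import Image
from torchvision import transforms

# Standard ImageNet mean/std; the C++ side would apply the same
# per-channel (x - mean) / std before filling the input buffer.
tfms = transforms.Compose([
    transforms.Resize((300, 300)),  # b3 input size from global_params_map
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
x = tfms(Image.open('img.jpg').convert('RGB')).unsqueeze(0)  # (1, 3, 300, 300)
```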

Diff for: efficientnet/gen_wts.py

+16
@@ -0,0 +1,16 @@
import torch
import struct
from efficientnet_pytorch import EfficientNet

model = EfficientNet.from_pretrained('efficientnet-b3')

model.eval()
f = open('efficientnet-b3.wts', 'w')
f.write('{}\n'.format(len(model.state_dict().keys())))
# one line per tensor: name, element count, then each float32 as big-endian hex
for k, v in model.state_dict().items():
    vr = v.reshape(-1).cpu().numpy()
    f.write('{} {} '.format(k, len(vr)))
    for vv in vr:
        f.write(' ')
        f.write(struct.pack('>f', float(vv)).hex())
    f.write('\n')
f.close()
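
As a quick sanity check on the generated file, here is a minimal sketch (a hypothetical helper, not part of this commit) that parses the format written above: the first line holds the tensor count, then each line has a tensor name, its element count, and that many big-endian float32 values in hex:

```python
# Hypothetical round-trip check for the .wts format.
import struct

def load_wts(path):
    weights = {}
    with open(path) as f:
        count = int(f.readline())
        for _ in range(count):
            parts = f.readline().split()
            name, n = parts[0], int(parts[1])
            # each value is a big-endian float32 encoded as 8 hex chars
            weights[name] = [struct.unpack('>f', bytes.fromhex(h))[0]
                             for h in parts[2:2 + n]]
    return weights

w = load_wts('efficientnet-b3.wts')
print(len(w), 'tensors loaded;', '_fc.bias has', len(w['_fc.bias']), 'values')
```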
