
Commit b6bf899

add cast trt converter (PaddlePaddle#43447)
* add cast trt converter
1 parent 8902a41 commit b6bf899
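
This commit adds a TensorRT converter for Paddle's cast op. The converter source itself (cast_op.cc) is among the five changed files but is not shown in this excerpt. As a minimal illustrative sketch: a dtype cast is conventionally lowered in TensorRT as an identity layer whose output type is overridden. The helper names and the dtype mapping below are illustrative assumptions, not the commit's actual code.

// Illustrative sketch, not the commit's cast_op.cc. In TensorRT, a cast is
// commonly expressed as an identity layer with a forced output data type.
#include "NvInfer.h"

// Hypothetical helper: map Paddle's proto::VarType dtype codes to TensorRT.
static nvinfer1::DataType MapDtype(int paddle_dtype) {
  switch (paddle_dtype) {
    case 2:  return nvinfer1::DataType::kINT32;  // INT32
    case 4:  return nvinfer1::DataType::kHALF;   // FP16
    case 5:  return nvinfer1::DataType::kFLOAT;  // FP32
    default: return nvinfer1::DataType::kFLOAT;
  }
}

// Core of a cast conversion: values pass through unchanged, and the output
// tensor's type is rewritten to the requested dtype.
static nvinfer1::ILayer *AddCast(nvinfer1::INetworkDefinition *network,
                                 nvinfer1::ITensor *input,
                                 int out_dtype) {
  nvinfer1::IIdentityLayer *layer = network->addIdentity(*input);
  layer->setOutputType(0, MapDtype(out_dtype));
  return layer;
}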

5 files changed: 331 additions & 67 deletions

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 94 additions & 49 deletions
@@ -104,7 +104,8 @@ bool IsPersistable(const framework::VarDesc *var) {
 }
 }  // namespace
 
-bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
+bool PaddleTensorToLoDTensor(const PaddleTensor &pt,
+                             framework::LoDTensor *t,
                              const platform::Place &place) {
   framework::DDim ddim = phi::make_ddim(pt.shape);
   void *input_ptr;
@@ -132,27 +133,31 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
 
   if (platform::is_cpu_place(place)) {
     // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
-    std::memcpy(static_cast<void *>(input_ptr), pt.data.data(),
-                pt.data.length());
+    std::memcpy(
+        static_cast<void *>(input_ptr), pt.data.data(), pt.data.length());
   } else if (platform::is_ipu_place(place)) {
 #ifdef PADDLE_WITH_IPU
-    std::memcpy(static_cast<void *>(input_ptr), pt.data.data(),
-                pt.data.length());
+    std::memcpy(
+        static_cast<void *>(input_ptr), pt.data.data(), pt.data.length());
 #else
     PADDLE_THROW(paddle::platform::errors::Fatal(
         "Not compile with WITH_IPU, should not reach here."));
 #endif
   } else if (platform::is_gpu_place(place)) {
-    PADDLE_ENFORCE_EQ(platform::is_xpu_place(place), false,
+    PADDLE_ENFORCE_EQ(platform::is_xpu_place(place),
+                      false,
                       platform::errors::InvalidArgument(
                           "Only one choice can be made between CPU and XPU."));
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto *dev_ctx =
         static_cast<const platform::CUDADeviceContext *>(pool.Get(place));
     auto dst_gpu_place = place;
-    memory::Copy(dst_gpu_place, static_cast<void *>(input_ptr),
-                 platform::CPUPlace(), pt.data.data(), pt.data.length(),
+    memory::Copy(dst_gpu_place,
+                 static_cast<void *>(input_ptr),
+                 platform::CPUPlace(),
+                 pt.data.data(),
+                 pt.data.length(),
                  dev_ctx->stream());
 #else
     PADDLE_THROW(paddle::platform::errors::Fatal(
@@ -161,8 +166,11 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
   } else if (platform::is_xpu_place(place)) {
 #ifdef PADDLE_WITH_XPU
     auto dst_xpu_place = place;
-    memory::Copy(dst_xpu_place, static_cast<void *>(input_ptr),
-                 platform::CPUPlace(), pt.data.data(), pt.data.length());
+    memory::Copy(dst_xpu_place,
+                 static_cast<void *>(input_ptr),
+                 platform::CPUPlace(),
+                 pt.data.data(),
+                 pt.data.length());
 #else
     PADDLE_THROW(paddle::platform::errors::Fatal(
         "Not compile with XPU, should not reach here."));
@@ -245,7 +253,8 @@ bool AnalysisPredictor::Init(
 
 void AnalysisPredictor::InitPlace() {
   if (config_.use_gpu()) {
-    PADDLE_ENFORCE_EQ(config_.use_xpu(), false,
+    PADDLE_ENFORCE_EQ(config_.use_xpu(),
+                      false,
                       platform::errors::InvalidArgument(
                           "Only one choice can be made between CPU and XPU."));
     place_ = paddle::platform::CUDAPlace(config_.gpu_device_id());
@@ -502,7 +511,8 @@ static bool IsPrepareDataOptTargetOp(framework::OpDesc *op) {
 }
 
 static void DisablePrepareDataOpt(
-    std::shared_ptr<framework::ProgramDesc> inference_program, int block,
+    std::shared_ptr<framework::ProgramDesc> inference_program,
+    int block,
     bool pre_disable_opt) {
   bool disable_opt = false;
   auto &infer_block = inference_program->Block(block);
@@ -512,8 +522,8 @@ static void DisablePrepareDataOpt(
     }
     if (op->HasAttr("sub_block")) {
       int blockID = op->GetBlockAttrId("sub_block");
-      DisablePrepareDataOpt(inference_program, blockID,
-                            disable_opt || pre_disable_opt);
+      DisablePrepareDataOpt(
+          inference_program, blockID, disable_opt || pre_disable_opt);
     }
     // disable prepare data if unfriendly op is found
     if (!disable_opt) {
@@ -531,8 +541,8 @@ bool AnalysisPredictor::PrepareExecutor() {
 #endif
   DisablePrepareDataOpt(inference_program_, 0, false);
 
-  executor_->Prepare(sub_scope_, *inference_program_, 0,
-                     config_.use_feed_fetch_ops_);
+  executor_->Prepare(
+      sub_scope_, *inference_program_, 0, config_.use_feed_fetch_ops_);
 
   PADDLE_ENFORCE_NOT_NULL(sub_scope_,
                           platform::errors::PreconditionNotMet(
@@ -578,8 +588,13 @@ bool AnalysisPredictor::PrepareFleetExecutor() {
     feed_fetch_vars.emplace_back(pair.second);
   }
   fleet_exe_->Init(config_.dist_config().carrier_id(),
-                   *(inference_program_.get()), scope_.get(), place_, 1,
-                   {task_node_.get()}, id_to_rank, feed_fetch_vars);
+                   *(inference_program_.get()),
+                   scope_.get(),
+                   place_,
+                   1,
+                   {task_node_.get()},
+                   id_to_rank,
+                   feed_fetch_vars);
   return true;
 }
 
@@ -616,8 +631,12 @@ bool AnalysisPredictor::CommInit() {
       peer_endpoints.emplace_back(
           config_.dist_config().trainer_endpoints()[rank]);
     }
-    InsertCommOp(var_name_base + std::to_string(order), ranks_in_group,
-                 rank_in_group, peer_endpoints, comm_init_block, ring_id);
+    InsertCommOp(var_name_base + std::to_string(order),
+                 ranks_in_group,
+                 rank_in_group,
+                 peer_endpoints,
+                 comm_init_block,
+                 ring_id);
     order += 1;
   }
   framework::NaiveExecutor e(place_);
@@ -629,8 +648,11 @@ bool AnalysisPredictor::CommInit() {
 }
 
 void AnalysisPredictor::InsertCommOp(
-    std::string tmp_var_name, int nranks, int rank,
-    const std::vector<std::string> &peer_endpoints, framework::BlockDesc *block,
+    std::string tmp_var_name,
+    int nranks,
+    int rank,
+    const std::vector<std::string> &peer_endpoints,
+    framework::BlockDesc *block,
     int ring_id) {
   /*
    * tmp_var_name: the var name for var comm_id
@@ -687,7 +709,8 @@ bool AnalysisPredictor::LoadConverterConfig(
           << config_.dist_config().comm_init_config() << "\n";
   std::ifstream fin(config_.dist_config().comm_init_config(), std::ios::in);
   PADDLE_ENFORCE_EQ(
-      static_cast<bool>(fin.is_open()), true,
+      static_cast<bool>(fin.is_open()),
+      true,
       platform::errors::NotFound(
           "Cannot open file %s, please confirm whether the file is normal.",
           config_.dist_config().comm_init_config()));
@@ -831,8 +854,9 @@ bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
   timer.tic();
   // set feed variable
   framework::Scope *scope = sub_scope_ ? sub_scope_ : scope_.get();
-  PADDLE_ENFORCE_NOT_NULL(scope, platform::errors::PreconditionNotMet(
-                                     "The scope should not be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      scope,
+      platform::errors::PreconditionNotMet("The scope should not be nullptr."));
   if (!SetFeed(inputs, scope)) {
     LOG(ERROR) << "fail to set feed";
     return false;
@@ -935,9 +959,11 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
   for (size_t i = 0; i < fetches_.size(); ++i) {
     int idx = BOOST_GET_CONST(int, fetches_[i]->GetAttr("col"));
     PADDLE_ENFORCE_EQ(
-        static_cast<size_t>(idx), i,
+        static_cast<size_t>(idx),
+        i,
         platform::errors::InvalidArgument(
-            "Fetch op's col attr(%d) should be equal to the index(%d)", idx,
+            "Fetch op's col attr(%d) should be equal to the index(%d)",
+            idx,
             i));
     framework::FetchType &fetch_var =
         framework::GetFetchVariable(*scope, "fetch", idx);
@@ -978,7 +1004,8 @@ void AnalysisPredictor::PrepareArgument() {
   if (!config_.model_dir().empty()) {
     argument_.SetModelDir(config_.model_dir());
   } else {
-    PADDLE_ENFORCE_EQ(config_.prog_file().empty(), false,
+    PADDLE_ENFORCE_EQ(config_.prog_file().empty(),
+                      false,
                       platform::errors::PreconditionNotMet(
                           "Either model_dir or prog_file should be set."));
     std::string dir = inference::analysis::GetDirRoot(config_.prog_file());
@@ -1123,7 +1150,8 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
   Analyzer().Run(&argument_);
 
   PADDLE_ENFORCE_EQ(
-      argument_.scope_valid(), true,
+      argument_.scope_valid(),
+      true,
       platform::errors::InvalidArgument("The argument scope should be valid."));
   VLOG(5) << "to prepare executor";
   ARGUMENT_CHECK_FIELD((&argument_), ir_analyzed_program);
@@ -1173,7 +1201,8 @@ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
   }
   VLOG(3) << "create AnalysisConfig";
   PADDLE_ENFORCE_EQ(
-      config.is_valid(), true,
+      config.is_valid(),
+      true,
       platform::errors::InvalidArgument(
           "Note: Each config can only be used for one predictor."));
 
@@ -1190,11 +1219,13 @@ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
   std::call_once(gflags_initialized, [&]() {
     std::vector<std::string> gflags;
     PADDLE_ENFORCE_GE(
-        config.memory_pool_init_size_mb(), 0.f,
+        config.memory_pool_init_size_mb(),
+        0.f,
         platform::errors::InvalidArgument(
             "The size of memory pool should be greater than 0."));
     PADDLE_ENFORCE_GE(
-        config.gpu_device_id(), 0,
+        config.gpu_device_id(),
+        0,
         platform::errors::InvalidArgument(
             "Invalid device id (%d). The device id should be greater than 0.",
             config.gpu_device_id()));
@@ -1303,8 +1334,9 @@ void AnalysisPredictor::PrepareFeedFetch() {
 }
 
 void AnalysisPredictor::CreateFeedFetchVar(framework::Scope *scope) {
-  PADDLE_ENFORCE_NOT_NULL(scope, platform::errors::InvalidArgument(
-                                     "The scope should not be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      scope,
+      platform::errors::InvalidArgument("The scope should not be nullptr."));
   auto *var = scope->Var("feed");
   var->GetMutable<framework::FeedList>();
   var = scope->Var("fetch");
@@ -1325,8 +1357,9 @@ AnalysisPredictor::GetInputTensorShape() {
   std::vector<std::string> names = GetInputNames();
   for (std::string name : names) {
     auto *var = inference_program_->Block(0).FindVar(name);
-    PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet(
-                                     "Input %s does not exist.", name));
+    PADDLE_ENFORCE_NOT_NULL(
+        var,
+        platform::errors::PreconditionNotMet("Input %s does not exist.", name));
     input_shapes[name] = var->GetShape();
   }
   return input_shapes;
@@ -1565,7 +1598,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() {
     std::vector<std::pair<int32_t, int32_t>> counter;
     for (auto &it : m) counter.push_back(it);
     std::sort(
-        counter.begin(), counter.end(),
+        counter.begin(),
+        counter.end(),
         [](std::pair<int32_t, int32_t> &a, std::pair<int32_t, int32_t> &b) {
           return a.second > b.second;
         });
@@ -1587,8 +1621,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() {
     opt_shapes[name] = opt_shape;
   }
 
-  inference::SerializeShapeRangeInfo(config_.shape_range_info_path(),
-                                     min_shapes, max_shapes, opt_shapes);
+  inference::SerializeShapeRangeInfo(
+      config_.shape_range_info_path(), min_shapes, max_shapes, opt_shapes);
 }
 
 bool AnalysisPredictor::LoadProgramDesc() {
@@ -1608,7 +1642,8 @@ bool AnalysisPredictor::LoadProgramDesc() {
       return false;
     }
     LOG(ERROR) << string::Sprintf(
-        "not valid model path '%s' or program path '%s'.", config_.model_dir(),
+        "not valid model path '%s' or program path '%s'.",
+        config_.model_dir(),
         config_.params_file());
     return false;
   }
@@ -1620,7 +1655,8 @@ bool AnalysisPredictor::LoadProgramDesc() {
     // Read binary
     std::ifstream fin(filename, std::ios::in | std::ios::binary);
     PADDLE_ENFORCE_EQ(
-        static_cast<bool>(fin.is_open()), true,
+        static_cast<bool>(fin.is_open()),
+        true,
         platform::errors::NotFound(
             "Cannot open file %s, please confirm whether the file is normal.",
             filename));
@@ -1722,7 +1758,8 @@ void AnalysisPredictor::ClearIntermediateTensor() {
 
 #if PADDLE_WITH_TENSORRT
 bool AnalysisPredictor::SaveTrtCalibToDisk() {
-  PADDLE_ENFORCE_EQ(config_.tensorrt_engine_enabled(), true,
+  PADDLE_ENFORCE_EQ(config_.tensorrt_engine_enabled(),
+                    true,
                     platform::errors::PreconditionNotMet(
                         "This func can be invoked only in trt mode"));
   auto &block = inference_program_->Block(0);
@@ -1963,6 +2000,7 @@ USE_TRT_CONVERTER(c_allreduce_sum)
 USE_TRT_CONVERTER(roll)
 USE_TRT_CONVERTER(strided_slice)
 USE_TRT_CONVERTER(transformer_input_convert)
+USE_TRT_CONVERTER(cast)
 USE_TRT_CONVERTER(recover_padding)
 USE_TRT_CONVERTER(remove_padding)
 USE_TRT_CONVERTER(top_k)
@@ -1990,8 +2028,10 @@ Predictor::Predictor(const Config &config) {
           << "Paddle2ONNX do't support convert the Model, fall back to using "
              "Paddle Inference.";
   } else {
-    predictor_ = paddle::CreatePaddlePredictor<
-        Config, paddle::PaddleEngineKind::kONNXRuntime>(config);
+    predictor_ =
+        paddle::CreatePaddlePredictor<Config,
+                                      paddle::PaddleEngineKind::kONNXRuntime>(
+            config);
     return;
   }
 #else
@@ -2001,8 +2041,10 @@ Predictor::Predictor(const Config &config) {
                     "fall back to using Paddle Inference.";
 #endif
   }
-  predictor_ = paddle::CreatePaddlePredictor<
-      Config, paddle::PaddleEngineKind::kAnalysis>(config);
+  predictor_ =
+      paddle::CreatePaddlePredictor<Config,
+                                    paddle::PaddleEngineKind::kAnalysis>(
+          config);
 }
 
 std::vector<std::string> Predictor::GetInputNames() {
@@ -2086,7 +2128,8 @@ std::shared_ptr<Predictor> CreatePredictor(const Config &config) { // NOLINT
 namespace services {
 PredictorPool::PredictorPool(const Config &config, size_t size) {
   PADDLE_ENFORCE_GE(
-      size, 1UL,
+      size,
+      1UL,
       paddle::platform::errors::InvalidArgument(
           "The predictor pool size should be greater than 1, but it's (%d)",
           size));
@@ -2105,9 +2148,11 @@ PredictorPool::PredictorPool(const Config &config, size_t size) {
 
 Predictor *PredictorPool::Retrive(size_t idx) {
   PADDLE_ENFORCE_LT(
-      idx, preds_.size() + 1,
+      idx,
+      preds_.size() + 1,
       paddle::platform::errors::InvalidArgument(
-          "There are (%d) predictors in the pool, but the idx is (%d)", idx,
+          "There are (%d) predictors in the pool, but the idx is (%d)",
+          idx,
           preds_.size() + 1));
   if (idx == 0) {
     return main_pred_.get();
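
For orientation: the USE_TRT_CONVERTER(cast) line added in this file is the linker-side half of Paddle's converter registration; the other half would live in the new cast_op.cc via REGISTER_TRT_OP_CONVERTER. A hedged sketch of the conventional pairing follows; the macro names are Paddle's existing convention, while the class body is elided and illustrative:

// In paddle/fluid/inference/tensorrt/convert/cast_op.cc (illustrative shape):
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace inference {
namespace tensorrt {
class CastOpConverter : public OpConverter {
  // operator()(...) would build the TRT layer(s) for the cast op here.
};
REGISTER_TRT_OP_CONVERTER(cast, CastOpConverter);
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

// In analysis_predictor.cc (this diff): forces the registration above to be
// linked into the final binary.
USE_TRT_CONVERTER(cast)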

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ list(
   preln_skip_layernorm.cc
   roll_op.cc
   transformer_input_convert_op.cc
+  cast_op.cc
   remove_padding_op.cc
   recover_padding_op.cc
   preln_residual_bias.cc
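
With cast_op.cc compiled into the converter list and the converter registered, cast ops inside a TensorRT subgraph can be converted instead of forcing a fallback to Paddle ops. A brief usage sketch of enabling the TensorRT engine through the standard Paddle Inference C++ API follows; the model paths are placeholders and the tuning numbers are arbitrary examples, not values from this commit:

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("model.pdmodel", "model.pdiparams");  // placeholder paths
  config.EnableUseGpu(100 /*initial GPU memory, MB*/, 0 /*device id*/);
  // Offload supported subgraphs (now including cast) to TensorRT.
  config.EnableTensorRtEngine(1 << 30 /*workspace bytes*/,
                              1 /*max batch size*/,
                              3 /*min subgraph size*/,
                              paddle_infer::PrecisionType::kFloat32,
                              false /*use_static*/,
                              false /*use_calib_mode*/);
  auto predictor = paddle_infer::CreatePredictor(config);
  (void)predictor;  // feed/fetch omitted in this sketch
  return 0;
}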
