
Commit 94f2651

Add force decoding in FasterTransformer (PaddlePaddle#873)
1 parent: d436ab7
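Force decoding constrains the first steps of generation to a given target prefix (the trg_word input threaded through this commit); once the prefix is exhausted, the configured decoding strategy takes over. A minimal, purely illustrative Python sketch of the idea (names are hypothetical; this is not the FasterTransformer implementation):

import numpy as np

def force_decode_step(logits, step, trg_word, pad_id=0):
    # logits:   [batch_size, vocab_size] scores for the current step
    # trg_word: [batch_size, prefix_len] int32 forced prefix; pad_id marks
    #           positions with nothing to force
    pred = np.argmax(logits, axis=-1)
    if trg_word is not None and step < trg_word.shape[1]:
        forced = trg_word[:, step]
        # Overwrite the model's prediction wherever a prefix token exists.
        pred = np.where(forced != pad_id, forced, pred)
    return pred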

File tree

15 files changed: +1654 −304 lines

examples/machine_translation/transformer/faster_transformer/encoder_decoding_predict.py

Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ def parse_args():
         choices=["beam_search", "topk_sampling", "topp_sampling"],
         help="Decoding strategy. Can be one of ['beam_search', 'topk_sampling', 'topp_sampling']. "
     )
-    parser.add_argument("--beam_size", default=5, type=int, help="Beam size. ")
+    parser.add_argument("--beam_size", default=4, type=int, help="Beam size. ")
     parser.add_argument(
         "--diversity_rate",
         default=0.0,

examples/machine_translation/transformer/faster_transformer/export_model.py

Lines changed: 6 additions & 1 deletion

@@ -129,7 +129,12 @@ def do_predict(args):
         input_spec=[
             # src_word
             paddle.static.InputSpec(
-                shape=[None, None], dtype="int64")
+                shape=[None, None], dtype="int64"),
+            # trg_word
+            # Support exporting a model which supports force decoding
+            # NOTE: Data type MUST be int32 !
+            # paddle.static.InputSpec(
+            #     shape=[None, None], dtype="int32")
         ])
 
     # Save converted static graph model
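To export a model that accepts a forced prefix, the trg_word spec above is uncommented. A sketch of the resulting input_spec, assuming the same paddle.jit.to_static call as in the file (the int32 requirement comes from the NOTE in the diff):

import paddle

# Sketch: input_spec for exporting with force-decoding support enabled.
input_spec = [
    # src_word
    paddle.static.InputSpec(shape=[None, None], dtype="int64"),
    # trg_word -- MUST be int32
    paddle.static.InputSpec(shape=[None, None], dtype="int32"),
]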

paddlenlp/ops/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ if(NOT WITH_GPU)
 endif()
 
 if(WITH_TRANSFORMER)
-  list(APPEND decoding_op_files fusion_decoding_op.cc fusion_decoding_op.cu)
+  list(APPEND decoding_op_files fusion_decoding_op.cc fusion_decoding_op.cu fusion_force_decoding_op.cc fusion_force_decoding_op.cu)
 endif()
 
 if(WITH_GPT)

paddlenlp/ops/faster_transformer/sample/decoding_sample.py

Lines changed: 3 additions & 0 deletions

@@ -27,6 +27,9 @@
 
 from paddlenlp.utils.log import logger
 
+paddle.seed(2)
+np.random.seed(2)
+
 
 def parse_args():
     parser = argparse.ArgumentParser()

paddlenlp/ops/faster_transformer/src/demo/transformer_e2e.cc

Lines changed: 9 additions & 0 deletions

@@ -183,6 +183,15 @@ class DataReader {
     src_word_t->Reshape({batch_size, max_len});
     src_word_t->CopyFromCpu(src_word_vec.data());
 
+    // NOTE: If the saved model supports force decoding, a nullptr must be
+    // given to trg_word to ensure the predictor works properly when not
+    // using force decoding.
+    /*
+     * auto trg_word_t = predictor->GetInputHandle("trg_word");
+     * trg_word_t->Reshape({0, 0});
+     * trg_word_t->CopyFromCpu((int*)nullptr);
+     */
+
     return true;
   }
 };
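The same precaution in the Python inference API might look like the sketch below (hedged: the "trg_word" input name comes from the comment above, the model paths are placeholders, and feeding an empty int32 array is assumed to be the Python analogue of the nullptr trick):

import numpy as np
from paddle import inference

config = inference.Config("transformer.pdmodel", "transformer.pdiparams")  # placeholder paths
predictor = inference.create_predictor(config)

# Run ordinary decoding on a force-decoding-capable model by feeding an
# empty trg_word, mirroring the C++ comment above.
trg_word_t = predictor.get_input_handle("trg_word")
trg_word_t.reshape([0, 0])
trg_word_t.copy_from_cpu(np.zeros([0, 0], dtype="int32"))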
paddlenlp/ops/faster_transformer/src/fusion_force_decoding_op.cc (new file)

Lines changed: 334 additions & 0 deletions
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>

#include "fusion_force_decoding_op.h"
#include "pd_traits.h"


std::vector<paddle::Tensor> DecodingForward(
    const paddle::Tensor& input,
    const paddle::Tensor& mem_seq_len,
    const paddle::Tensor& word_embedding,
    const std::vector<paddle::Tensor>& self_ln_weight,
    const std::vector<paddle::Tensor>& self_ln_bias,
    const std::vector<paddle::Tensor>& self_q_weight,
    const std::vector<paddle::Tensor>& self_q_bias,
    const std::vector<paddle::Tensor>& self_k_weight,
    const std::vector<paddle::Tensor>& self_k_bias,
    const std::vector<paddle::Tensor>& self_v_weight,
    const std::vector<paddle::Tensor>& self_v_bias,
    const std::vector<paddle::Tensor>& self_out_weight,
    const std::vector<paddle::Tensor>& self_out_bias,
    const std::vector<paddle::Tensor>& cross_ln_weight,
    const std::vector<paddle::Tensor>& cross_ln_bias,
    const std::vector<paddle::Tensor>& cross_q_weight,
    const std::vector<paddle::Tensor>& cross_q_bias,
    const std::vector<paddle::Tensor>& cross_k_weight,
    const std::vector<paddle::Tensor>& cross_k_bias,
    const std::vector<paddle::Tensor>& cross_v_weight,
    const std::vector<paddle::Tensor>& cross_v_bias,
    const std::vector<paddle::Tensor>& cross_out_weight,
    const std::vector<paddle::Tensor>& cross_out_bias,
    const std::vector<paddle::Tensor>& ffn_ln_weight,
    const std::vector<paddle::Tensor>& ffn_ln_bias,
    const std::vector<paddle::Tensor>& ffn_inter_weight,
    const std::vector<paddle::Tensor>& ffn_inter_bias,
    const std::vector<paddle::Tensor>& ffn_out_weight,
    const std::vector<paddle::Tensor>& ffn_out_bias,
    const paddle::Tensor& decoder_ln_weight,
    const paddle::Tensor& decoder_ln_bias,
    const paddle::Tensor& embedding_weight,
    const paddle::Tensor& embedding_bias,
    const paddle::Tensor& positional_embedding_weight,
    const paddle::Tensor& trg_word,
    const std::string& decoding_strategy,
    const int& beam_size,
    const int& topk,
    const float& topp,
    const int& n_head,
    const int& size_per_head,
    const int& num_layer,
    const int& bos_id,
    const int& eos_id,
    const int64_t& max_len,
    const float& beam_search_diversity_rate,
    const bool& rel_len,
    const float& alpha) {
  int batch_size = input.shape()[0];
  int max_out_len = rel_len ? max_len + input.shape()[1] : max_len;

  std::vector<int64_t> output_dims;
  std::vector<int64_t> parent_ids_dims;
  std::vector<int64_t> sequence_length_dims({batch_size});
  if (decoding_strategy == "beam_search") {
    batch_size /= beam_size;
    output_dims = {max_out_len, batch_size, beam_size};
    parent_ids_dims = output_dims;
  } else if (decoding_strategy == "beam_search_v2") {
    // Use separate alive and finished beam queues to avoid the decrease of
    // alive beams. The outputs must include both the finished and alive beams
    // to trace the full path.
    sequence_length_dims = {batch_size * 2};
    batch_size /= beam_size;
    output_dims = {max_out_len, batch_size, beam_size * 2};
    parent_ids_dims = output_dims;
  } else if (decoding_strategy == "sampling") {
    output_dims = {max_out_len, batch_size};
    parent_ids_dims = {1};
  } else {
    PD_THROW("Not supported decoding strategy. ");
  }

  if (input.place() == paddle::PlaceType::kGPU) {
    auto output_ids = paddle::Tensor(paddle::PlaceType::kGPU, output_dims);
    auto parent_ids = paddle::Tensor(paddle::PlaceType::kGPU, parent_ids_dims);
    auto sequence_length =
        paddle::Tensor(paddle::PlaceType::kGPU, sequence_length_dims);

    paddle::Tensor seq_len = paddle::Tensor(paddle::PlaceType::kGPU);

    if (mem_seq_len.place() != paddle::PlaceType::kGPU) {
      seq_len = mem_seq_len.copy_to<int>(paddle::PlaceType::kGPU);
    } else {
      seq_len = mem_seq_len;
    }

    return DecodingCUDAForward(input,
                               seq_len,
                               word_embedding,
                               self_ln_weight,
                               self_ln_bias,
                               self_q_weight,
                               self_q_bias,
                               self_k_weight,
                               self_k_bias,
                               self_v_weight,
                               self_v_bias,
                               self_out_weight,
                               self_out_bias,
                               cross_ln_weight,
                               cross_ln_bias,
                               cross_q_weight,
                               cross_q_bias,
                               cross_k_weight,
                               cross_k_bias,
                               cross_v_weight,
                               cross_v_bias,
                               cross_out_weight,
                               cross_out_bias,
                               ffn_ln_weight,
                               ffn_ln_bias,
                               ffn_inter_weight,
                               ffn_inter_bias,
                               ffn_out_weight,
                               ffn_out_bias,
                               decoder_ln_weight,
                               decoder_ln_bias,
                               embedding_weight,
                               embedding_bias,
                               positional_embedding_weight,
                               trg_word,
                               output_ids,
                               parent_ids,
                               sequence_length,
                               decoding_strategy,
                               beam_size,
                               topk,
                               topp,
                               n_head,
                               size_per_head,
                               num_layer,
                               bos_id,
                               eos_id,
                               max_out_len,
                               beam_search_diversity_rate,
                               alpha);
  } else {
    PD_THROW("Not implemented place. Only GPU is supported. ");
  }
}

std::vector<std::vector<int64_t>> DecodingInferShape(
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& mem_seq_len_shape,
    const std::vector<int64_t>& word_embedding_shape,
    const std::vector<std::vector<int64_t>>& self_ln_weight_shapes,
    const std::vector<std::vector<int64_t>>& self_ln_bias_shapes,
    const std::vector<std::vector<int64_t>>& self_q_weight_shapes,
    const std::vector<std::vector<int64_t>>& self_q_bias_shapes,
    const std::vector<std::vector<int64_t>>& self_k_weight_shapes,
    const std::vector<std::vector<int64_t>>& self_k_bias_shapes,
    const std::vector<std::vector<int64_t>>& self_v_weight_shapes,
    const std::vector<std::vector<int64_t>>& self_v_bias_shapes,
    const std::vector<std::vector<int64_t>>& self_out_weight_shapes,
    const std::vector<std::vector<int64_t>>& self_out_bias_shapes,
    const std::vector<std::vector<int64_t>>& cross_ln_weight_shapes,
    const std::vector<std::vector<int64_t>>& cross_ln_bias_shapes,
    const std::vector<std::vector<int64_t>>& cross_q_weight_shapes,
    const std::vector<std::vector<int64_t>>& cross_q_bias_shapes,
    const std::vector<std::vector<int64_t>>& cross_k_weight_shapes,
    const std::vector<std::vector<int64_t>>& cross_k_bias_shapes,
    const std::vector<std::vector<int64_t>>& cross_v_weight_shapes,
    const std::vector<std::vector<int64_t>>& cross_v_bias_shapes,
    const std::vector<std::vector<int64_t>>& cross_out_weight_shapes,
    const std::vector<std::vector<int64_t>>& cross_out_bias_shapes,
    const std::vector<std::vector<int64_t>>& ffn_ln_weight_shapes,
    const std::vector<std::vector<int64_t>>& ffn_ln_bias_shapes,
    const std::vector<std::vector<int64_t>>& ffn_inter_weight_shapes,
    const std::vector<std::vector<int64_t>>& ffn_inter_bias_shapes,
    const std::vector<std::vector<int64_t>>& ffn_out_weight_shapes,
    const std::vector<std::vector<int64_t>>& ffn_out_bias_shapes,
    const std::vector<int64_t>& decoder_ln_weight_shape,
    const std::vector<int64_t>& decoder_ln_bias_shape,
    const std::vector<int64_t>& embedding_weight_shape,
    const std::vector<int64_t>& embedding_bias_shape,
    const std::vector<int64_t>& positional_embedding_weight_shape,
    const std::vector<int64_t>& trg_word_shape,
    const std::string& decoding_strategy,
    const int& beam_size,
    const int& topk,
    const float& topp,
    const int& n_head,
    const int& size_per_head,
    const int& num_layer,
    const int& bos_id,
    const int& eos_id,
    const int64_t& max_len,
    const float& beam_search_diversity_rate,
    const bool& rel_len,
    const float& alpha) {
  int batch_size = input_shape[0];

  std::vector<int64_t> output_dims;
  std::vector<int64_t> sequence_length_dims({batch_size});
  if (decoding_strategy == "beam_search") {
    if (batch_size != -1) {
      batch_size /= beam_size;
    }
    output_dims = {max_len, batch_size, beam_size};
    return {output_dims, output_dims, sequence_length_dims};
  } else if (decoding_strategy == "beam_search_v2") {
    // Use separate alive and finished beam queues to avoid the decrease of
    // alive beams. The outputs must include both the finished and alive beams
    // to trace the full path.
    sequence_length_dims = {batch_size * 2};
    if (batch_size != -1) {
      batch_size /= beam_size;
    }
    output_dims = {max_len, batch_size, beam_size * 2};
    return {output_dims, output_dims, sequence_length_dims};
  } else if (decoding_strategy == "sampling") {
    output_dims = {max_len, batch_size};
    return {output_dims, {1}, sequence_length_dims};
  } else {
    PD_THROW("Not supported decoding strategy. ");
  }
}

std::vector<paddle::DataType> DecodingInferDtype(
    const paddle::DataType& input,
    const paddle::DataType& mem_seq_len,
    const paddle::DataType& word_embedding,
    const std::vector<paddle::DataType>& self_ln_weight,
    const std::vector<paddle::DataType>& self_ln_bias,
    const std::vector<paddle::DataType>& self_q_weight,
    const std::vector<paddle::DataType>& self_q_bias,
    const std::vector<paddle::DataType>& self_k_weight,
    const std::vector<paddle::DataType>& self_k_bias,
    const std::vector<paddle::DataType>& self_v_weight,
    const std::vector<paddle::DataType>& self_v_bias,
    const std::vector<paddle::DataType>& self_out_weight,
    const std::vector<paddle::DataType>& self_out_bias,
    const std::vector<paddle::DataType>& cross_ln_weight,
    const std::vector<paddle::DataType>& cross_ln_bias,
    const std::vector<paddle::DataType>& cross_q_weight,
    const std::vector<paddle::DataType>& cross_q_bias,
    const std::vector<paddle::DataType>& cross_k_weight,
    const std::vector<paddle::DataType>& cross_k_bias,
    const std::vector<paddle::DataType>& cross_v_weight,
    const std::vector<paddle::DataType>& cross_v_bias,
    const std::vector<paddle::DataType>& cross_out_weight,
    const std::vector<paddle::DataType>& cross_out_bias,
    const std::vector<paddle::DataType>& ffn_ln_weight,
    const std::vector<paddle::DataType>& ffn_ln_bias,
    const std::vector<paddle::DataType>& ffn_inter_weight,
    const std::vector<paddle::DataType>& ffn_inter_bias,
    const std::vector<paddle::DataType>& ffn_out_weight,
    const std::vector<paddle::DataType>& ffn_out_bias,
    const paddle::DataType& decoder_ln_weight,
    const paddle::DataType& decoder_ln_bias,
    const paddle::DataType& embedding_weight,
    const paddle::DataType& embedding_bias,
    const paddle::DataType& positional_embedding_weight,
    const paddle::DataType& trg_word) {
  return {paddle::DataType::INT32,
          paddle::DataType::INT32,
          paddle::DataType::INT32};
}

PD_BUILD_OP(fusion_force_decoding)
    .Inputs({"Input",
             "MemSeqLen",
             "WordEmbedding",
             paddle::Vec("SelfLayernormWeight"),
             paddle::Vec("SelfLayernormBias"),
             paddle::Vec("SelfQueryWeight"),
             paddle::Vec("SelfQueryBias"),
             paddle::Vec("SelfKeyWeight"),
             paddle::Vec("SelfKeyBias"),
             paddle::Vec("SelfValueWeight"),
             paddle::Vec("SelfValueBias"),
             paddle::Vec("SelfOutWeight"),
             paddle::Vec("SelfOutBias"),
             paddle::Vec("CrossLayernormWeight"),
             paddle::Vec("CrossLayernormBias"),
             paddle::Vec("CrossQueryWeight"),
             paddle::Vec("CrossQueryBias"),
             paddle::Vec("CrossKeyWeight"),
             paddle::Vec("CrossKeyBias"),
             paddle::Vec("CrossValueWeight"),
             paddle::Vec("CrossValueBias"),
             paddle::Vec("CrossOutWeight"),
             paddle::Vec("CrossOutBias"),
             paddle::Vec("FFNLayernormWeight"),
             paddle::Vec("FFNLayernormBias"),
             paddle::Vec("FFNInterWeight"),
             paddle::Vec("FFNInterBias"),
             paddle::Vec("FFNOutWeight"),
             paddle::Vec("FFNOutBias"),
             "DecoderLayernormWeight",
             "DecoderLayernormBias",
             "EmbWeight",
             "EmbBias",
             "PositionEncEmb",
             "TrgWord"})
    .Outputs({"OutputIds", "ParentIds", "SequenceLength"})
    .Attrs({"decoding_strategy: std::string",
            "beam_size: int",
            "topk: int",
            "topp: float",
            "n_head: int",
            "size_per_head: int",
            "num_layer: int",
            "bos_id: int",
            "eos_id: int",
            "max_len: int64_t",
            "beam_search_diversity_rate: float",
            "rel_len: bool",
            "alpha: float"})
    .SetKernelFn(PD_KERNEL(DecodingForward))
    .SetInferShapeFn(PD_INFER_SHAPE(DecodingInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(DecodingInferDtype));
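From Python, the new fusion_force_decoding op would be reached through the FasterTransformer wrapper in paddlenlp.ops. A hypothetical usage sketch (the constructor arguments and the trg_word keyword are assumptions based on this commit's op inputs, not a documented signature):

import paddle
from paddlenlp.ops import FasterTransformer

# Hypothetical sketch -- argument names below are illustrative assumptions.
model = FasterTransformer(
    src_vocab_size=30000, trg_vocab_size=30000, max_length=256,
    n_layer=6, n_head=8, d_model=512, d_inner_hid=2048,
    dropout=0.1, weight_sharing=True,
    decoding_strategy="beam_search", beam_size=4, max_out_len=64)

src_word = paddle.randint(0, 30000, shape=[2, 10], dtype="int64")
# Forced int32 prefixes; decoding continues freely after each prefix ends.
trg_word = paddle.to_tensor([[2, 15, 30], [2, 99, 7]], dtype="int32")
output_ids = model(src_word=src_word, trg_word=trg_word)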
