
Commit f9af064

Make FT selectable (PaddlePaddle#826)
* Make FT selectable
* update
* fix comments
Parent: 84257cd

2 files changed: +52, -22 lines

examples/machine_translation/transformer/predict.py

Lines changed: 7 additions & 1 deletion
@@ -30,6 +30,10 @@ def parse_args():
         type=str,
         help="The file for testing. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used to process testing."
     )
+    parser.add_argument(
+        "--without_ft",
+        action="store_true",
+        help="Whether to use Faster Transformer to do predict. ")
     args = parser.parse_args()
     return args
 

@@ -78,7 +82,8 @@ def do_predict(args):
         bos_id=args.bos_idx,
         eos_id=args.eos_idx,
         beam_size=args.beam_size,
-        max_out_len=args.max_out_len)
+        max_out_len=args.max_out_len,
+        use_ft=not args.without_ft)
 
     # Load the trained model
     assert args.init_from_params, (
@@ -114,6 +119,7 @@ def do_predict(args):
         args = AttrDict(yaml.safe_load(f))
         args.benchmark = ARGS.benchmark
         args.test_file = ARGS.test_file
+        args.without_ft = ARGS.without_ft
     pprint(args)
 
     do_predict(args)
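Taken together, the three hunks above wire a new CLI switch through to the generator: `--without_ft` is a `store_true` flag, so Faster Transformer stays enabled by default and `use_ft` is simply its negation. A minimal, self-contained sketch of just that wiring (argparse only; the simulated argv lists are illustrative, not from the commit):

import argparse

parser = argparse.ArgumentParser()
# Same flag definition as the diff above.
parser.add_argument(
    "--without_ft",
    action="store_true",
    help="Whether to use Faster Transformer to do predict.")

# Simulate both invocations instead of reading sys.argv.
for argv in ([], ["--without_ft"]):
    args = parser.parse_args(argv)
    use_ft = not args.without_ft  # the inversion done in do_predict()
    print(argv, "-> use_ft =", use_ft)
# prints: [] -> use_ft = True
#         ['--without_ft'] -> use_ft = False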

paddlenlp/ops/faster_transformer/transformer/faster_transformer.py

Lines changed: 45 additions & 21 deletions
@@ -268,13 +268,14 @@ class TransformerGenerator(paddle.nn.Layer):
         max_out_len (int, optional):
             The maximum output length. Defaults to 256.
         kwargs:
-            The key word arguments can be `output_time_major` and `use_fp16_decoding`.
+            The key word arguments can be `output_time_major`, `use_fp16_decoding` and `use_ft`.
             `output_time_major(bool, optional)`: Indicate the data layout of predicted
             Tensor. If `False`, the data layout would be batch major with shape
             `[batch_size, seq_len, beam_size]`. If `True`, the data layout would
             be time major with shape `[seq_len, batch_size, beam_size]`. Default
             to `False`. `use_fp16_decoding(bool, optional)`: Whether to use fp16
-            for decoding.
+            for decoding. `use_ft(bool, optional)`: Whether to use Faster Transformer
+            for decoding.
     """
 
     def __init__(self,
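Per the updated docstring, callers can now opt out of Faster Transformer explicitly. A usage sketch, assuming the import path used by the example scripts (`from paddlenlp.ops import TransformerGenerator`); every hyperparameter value below is illustrative, only the keyword names come from the diff:

from paddlenlp.ops import TransformerGenerator

# Illustrative base-transformer values; use_ft=False forces the plain
# InferTransformerModel path without attempting to build the custom op.
transformer = TransformerGenerator(
    src_vocab_size=30000,
    trg_vocab_size=30000,
    max_length=256,
    num_encoder_layers=6,
    num_decoder_layers=6,
    n_head=8,
    d_model=512,
    d_inner_hid=2048,
    dropout=0.1,
    weight_sharing=True,
    bos_id=0,
    eos_id=1,
    beam_size=4,
    max_out_len=256,
    use_ft=False)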
@@ -303,25 +304,48 @@ def __init__(self,
         self.max_length = max_length
         self.output_time_major = kwargs.pop("output_time_major", True)
         use_fp16_decoding = kwargs.pop("use_fp16_decoding", False)
-        try:
-            load("FasterTransformer", verbose=True)
-            self.transformer = FasterTransformer(
-                src_vocab_size=src_vocab_size,
-                trg_vocab_size=trg_vocab_size,
-                max_length=max_length,
-                num_encoder_layers=num_encoder_layers,
-                num_decoder_layers=num_decoder_layers,
-                n_head=n_head,
-                d_model=d_model,
-                d_inner_hid=d_inner_hid,
-                dropout=dropout,
-                weight_sharing=weight_sharing,
-                bos_id=bos_id,
-                eos_id=eos_id,
-                beam_size=beam_size,
-                max_out_len=max_out_len,
-                use_fp16_decoding=use_fp16_decoding)
-        except Exception:
+        use_ft = kwargs.pop("use_ft", True)
+
+        if use_ft:
+            try:
+                load("FasterTransformer", verbose=True)
+                self.transformer = FasterTransformer(
+                    src_vocab_size=src_vocab_size,
+                    trg_vocab_size=trg_vocab_size,
+                    max_length=max_length,
+                    num_encoder_layers=num_encoder_layers,
+                    num_decoder_layers=num_decoder_layers,
+                    n_head=n_head,
+                    d_model=d_model,
+                    d_inner_hid=d_inner_hid,
+                    dropout=dropout,
+                    weight_sharing=weight_sharing,
+                    bos_id=bos_id,
+                    eos_id=eos_id,
+                    beam_size=beam_size,
+                    max_out_len=max_out_len,
+                    use_fp16_decoding=use_fp16_decoding)
+            except Exception:
+                logger.warning(
+                    "Exception occurs when using Faster Transformer. " \
+                    "The original forward will be involved. ")
+                self.transformer = InferTransformerModel(
+                    src_vocab_size=src_vocab_size,
+                    trg_vocab_size=trg_vocab_size,
+                    max_length=max_length,
+                    num_encoder_layers=num_encoder_layers,
+                    num_decoder_layers=num_decoder_layers,
+                    n_head=n_head,
+                    d_model=d_model,
+                    d_inner_hid=d_inner_hid,
+                    dropout=dropout,
+                    weight_sharing=weight_sharing,
+                    bos_id=bos_id,
+                    eos_id=eos_id,
+                    beam_size=beam_size,
+                    max_out_len=max_out_len,
+                    output_time_major=self.output_time_major)
+        else:
             self.transformer = InferTransformerModel(
                 src_vocab_size=src_vocab_size,
                 trg_vocab_size=trg_vocab_size,
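The rewritten `__init__` follows a select-with-fallback pattern: pop the opt-in kwarg with the fast path as the default, try to build the optimized backend, and fall back to the reference implementation on any failure, or skip the attempt entirely when `use_ft=False`. A standalone sketch of that pattern; `build_fast`, `build_reference`, and `make_transformer` are hypothetical stand-ins, not PaddleNLP APIs:

import logging

logger = logging.getLogger(__name__)

def build_fast():
    # Stand-in for load("FasterTransformer") + FasterTransformer(...);
    # fails here to demonstrate the fallback branch.
    raise RuntimeError("custom op not available")

def build_reference():
    # Stand-in for InferTransformerModel(...).
    return "reference-model"

def make_transformer(**kwargs):
    # Same selection logic as TransformerGenerator.__init__ above.
    use_ft = kwargs.pop("use_ft", True)  # fast path is the default
    if use_ft:
        try:
            return build_fast()
        except Exception:
            logger.warning("Exception occurs when using Faster Transformer. "
                           "The original forward will be involved.")
            return build_reference()
    else:
        return build_reference()

print(make_transformer(use_ft=False))  # reference-model, FT never attempted
print(make_transformer())              # reference-model, after fallback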
