@@ -268,13 +268,14 @@ class TransformerGenerator(paddle.nn.Layer):
         max_out_len (int, optional):
             The maximum output length. Defaults to 256.
         kwargs:
-            The key word arguments can be `output_time_major` and `use_fp16_decoding`.
+            The key word arguments can be `output_time_major`, `use_fp16_decoding` and `use_ft`.
             `output_time_major(bool, optional)`: Indicate the data layout of predicted
             Tensor. If `False`, the data layout would be batch major with shape
             `[batch_size, seq_len, beam_size]`. If `True`, the data layout would
             be time major with shape `[seq_len, batch_size, beam_size]`. Default
             to `False`. `use_fp16_decoding(bool, optional)`: Whether to use fp16
-            for decoding.
+            for decoding. `use_ft(bool, optional)`: Whether to use Faster Transformer
+            for decoding.
     """

     def __init__(self,
@@ -303,25 +304,48 @@ def __init__(self,
         self.max_length = max_length
         self.output_time_major = kwargs.pop("output_time_major", True)
         use_fp16_decoding = kwargs.pop("use_fp16_decoding", False)
-        try:
-            load("FasterTransformer", verbose=True)
-            self.transformer = FasterTransformer(
-                src_vocab_size=src_vocab_size,
-                trg_vocab_size=trg_vocab_size,
-                max_length=max_length,
-                num_encoder_layers=num_encoder_layers,
-                num_decoder_layers=num_decoder_layers,
-                n_head=n_head,
-                d_model=d_model,
-                d_inner_hid=d_inner_hid,
-                dropout=dropout,
-                weight_sharing=weight_sharing,
-                bos_id=bos_id,
-                eos_id=eos_id,
-                beam_size=beam_size,
-                max_out_len=max_out_len,
-                use_fp16_decoding=use_fp16_decoding)
-        except Exception:
+        use_ft = kwargs.pop("use_ft", True)
+
+        if use_ft:
+            try:
+                load("FasterTransformer", verbose=True)
+                self.transformer = FasterTransformer(
+                    src_vocab_size=src_vocab_size,
+                    trg_vocab_size=trg_vocab_size,
+                    max_length=max_length,
+                    num_encoder_layers=num_encoder_layers,
+                    num_decoder_layers=num_decoder_layers,
+                    n_head=n_head,
+                    d_model=d_model,
+                    d_inner_hid=d_inner_hid,
+                    dropout=dropout,
+                    weight_sharing=weight_sharing,
+                    bos_id=bos_id,
+                    eos_id=eos_id,
+                    beam_size=beam_size,
+                    max_out_len=max_out_len,
+                    use_fp16_decoding=use_fp16_decoding)
+            except Exception:
+                logger.warning(
+                    "Exception occurs when using Faster Transformer. " \
+                    "The original forward will be involved. ")
+                self.transformer = InferTransformerModel(
+                    src_vocab_size=src_vocab_size,
+                    trg_vocab_size=trg_vocab_size,
+                    max_length=max_length,
+                    num_encoder_layers=num_encoder_layers,
+                    num_decoder_layers=num_decoder_layers,
+                    n_head=n_head,
+                    d_model=d_model,
+                    d_inner_hid=d_inner_hid,
+                    dropout=dropout,
+                    weight_sharing=weight_sharing,
+                    bos_id=bos_id,
+                    eos_id=eos_id,
+                    beam_size=beam_size,
+                    max_out_len=max_out_len,
+                    output_time_major=self.output_time_major)
+        else:
             self.transformer = InferTransformerModel(
                 src_vocab_size=src_vocab_size,
                 trg_vocab_size=trg_vocab_size,
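A minimal usage sketch of the generator after this change, for context. It assumes `TransformerGenerator` is exposed from `paddlenlp.ops` and uses placeholder hyperparameter values that are not part of this PR; with `use_ft=False` the FasterTransformer JIT load is skipped entirely and the plain `InferTransformerModel` path is used, while the default `use_ft=True` tries the custom op first and falls back on any exception.

import paddle
from paddlenlp.ops import TransformerGenerator  # assumed import path

# Placeholder hyperparameters for illustration only.
generator = TransformerGenerator(
    src_vocab_size=30000,
    trg_vocab_size=30000,
    max_length=256,
    num_encoder_layers=6,
    num_decoder_layers=6,
    n_head=8,
    d_model=512,
    d_inner_hid=2048,
    dropout=0.1,
    weight_sharing=True,
    bos_id=0,
    eos_id=1,
    beam_size=4,
    max_out_len=256,
    use_ft=False)  # skip FasterTransformer and use InferTransformerModel directly
generator.eval()

# src_word holds int64 token ids of shape [batch_size, src_len].
src_word = paddle.randint(low=2, high=30000, shape=[2, 8], dtype="int64")
finished_seq = generator(src_word)  # beam results; layout follows the output_time_major kwarg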