@@ -424,6 +424,8 @@ def forward(self, inputs, lengths):
424424
425425 # last_ids: batch_size
426426 scores , last_ids = alpha .max (1 ), alpha .argmax (1 )
427+ if max_seq_len == 1 :
428+ return scores , last_ids .unsqueeze (1 )
427429 # Trace back the best path
428430 # historys: seq_len, batch_size, n_labels
429431 historys = paddle .stack (historys )
@@ -438,10 +440,14 @@ def forward(self, inputs, lengths):
438440 # hist: batch_size, n_labels
439441 left_length = left_length + 1
440442 gather_idx = batch_offset + last_ids
441- tag_mask = paddle .cast ((left_length >= 0 ), 'int64' )
443+ tag_mask = paddle .cast ((left_length > 0 ), 'int64' )
442444 last_ids_update = paddle .gather (hist .flatten (),
443445 gather_idx ) * tag_mask
446+ zero_len_mask = paddle .cast ((left_length == 0 ), 'int64' )
447+ last_ids_update = last_ids_update * (1 - zero_len_mask
448+ ) + last_ids * zero_len_mask
444449 batch_path .append (last_ids_update )
450+ tag_mask = paddle .cast ((left_length >= 0 ), 'int64' )
445451 last_ids = last_ids_update + last_ids * (1 - tag_mask )
446452 batch_path = paddle .reverse (paddle .stack (batch_path , 1 ), [1 ])
447453 return scores , batch_path
0 commit comments