27 | 27 | import paddle.nn.functional as F |
28 | 28 |
|
29 | 29 | from paddlenlp.transformers import BertModel, BertForSequenceClassification, BertTokenizer |
30 | | -from paddlenlp.transformers import TinyBertModel, TinyBertForSequenceClassification, TinyBertTokenizer |
31 | | -from paddlenlp.transformers import TinyBertForSequenceClassification, TinyBertTokenizer |
32 | | -from paddlenlp.transformers import RobertaForSequenceClassification, RobertaTokenizer |
33 | 30 | from paddlenlp.utils.log import logger |
34 | 31 | from paddleslim.nas.ofa import OFA, utils |
35 | 32 | from paddleslim.nas.ofa.convert_super import Convert, supernet |
36 | 33 | from paddleslim.nas.ofa.layers import BaseBlock |
37 | 34 |
|
38 | | -MODEL_CLASSES = { |
39 | | - "bert": (BertForSequenceClassification, BertTokenizer), |
40 | | - "roberta": (RobertaForSequenceClassification, RobertaTokenizer), |
41 | | - "tinybert": (TinyBertForSequenceClassification, TinyBertTokenizer), |
42 | | -} |
| 35 | +MODEL_CLASSES = {"bert": (BertForSequenceClassification, BertTokenizer), } |
43 | 36 |
|
44 | 37 |
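Reviewer note: with the TinyBERT and RoBERTa entries dropped, MODEL_CLASSES keeps only the BERT pair. In run_glue-style scripts this mapping is typically consumed as in the sketch below; the argument and variable names are assumptions for illustration, not part of this diff:

    # Hypothetical lookup mirroring the usual run_glue pattern.
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]  # e.g. "bert"
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(
        args.model_name_or_path, num_classes=num_labels)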
|
45 | | -def tinybert_forward(self, input_ids, token_type_ids=None, attention_mask=None): |
| 38 | +def bert_forward(self, |
| 39 | + input_ids, |
| 40 | + token_type_ids=None, |
| 41 | + position_ids=None, |
| 42 | + attention_mask=None, |
| 43 | + output_hidden_states=False): |
46 | 44 | wtype = self.pooler.dense.fn.weight.dtype if hasattr( |
47 | 45 | self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype |
48 | 46 | if attention_mask is None: |
49 | 47 | attention_mask = paddle.unsqueeze( |
50 | 48 | (input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2]) |
51 | | - embedding_output = self.embeddings(input_ids, token_type_ids) |
52 | | - encoded_layer = self.encoder(embedding_output, attention_mask) |
53 | | - pooled_output = self.pooler(encoded_layer) |
54 | | - |
55 | | - return encoded_layer, pooled_output |
| 49 | + else: |
| 50 | + if attention_mask.ndim == 2: |
| 51 | + # attention_mask [batch_size, sequence_length] -> [batch_size, 1, 1, sequence_length] |
| 52 | + attention_mask = attention_mask.unsqueeze(axis=[1, 2]) |
| 53 | + |
| 54 | + embedding_output = self.embeddings( |
| 55 | + input_ids=input_ids, |
| 56 | + position_ids=position_ids, |
| 57 | + token_type_ids=token_type_ids) |
| 58 | + if output_hidden_states: |
| 59 | + output = embedding_output |
| 60 | + encoder_outputs = [] |
| 61 | + for mod in self.encoder.layers: |
| 62 | + output = mod(output, src_mask=attention_mask) |
| 63 | + encoder_outputs.append(output) |
| 64 | + if self.encoder.norm is not None: |
| 65 | + encoder_outputs[-1] = self.encoder.norm(encoder_outputs[-1]) |
| 66 | + pooled_output = self.pooler(encoder_outputs[-1]) |
| 67 | + else: |
| 68 | + sequence_output = self.encoder(embedding_output, attention_mask) |
| 69 | + pooled_output = self.pooler(sequence_output) |
| 70 | + if output_hidden_states: |
| 71 | + return encoder_outputs, pooled_output |
| 72 | + else: |
| 73 | + return sequence_output, pooled_output |
56 | 74 |
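Reviewer note: when attention_mask is None, bert_forward builds an additive mask from pad_token_id: padding positions receive -1e9 before the softmax, so their attention weights collapse to roughly zero; a user-supplied 2-D mask is only broadcast to [batch_size, 1, 1, seq_len]. A toy illustration of the additive-mask arithmetic, assuming pad_token_id is 0:

    import paddle
    import paddle.nn.functional as F

    input_ids = paddle.to_tensor([[101, 2023, 102, 0, 0]])   # last two tokens are padding
    mask = paddle.unsqueeze(
        (input_ids == 0).astype("float32") * -1e9, axis=[1, 2])
    scores = paddle.zeros([1, 1, 5, 5])          # stand-in attention scores
    probs = F.softmax(scores + mask, axis=-1)    # padded columns get ~0 weight
    print(probs[0, 0, 0])                        # ~[0.333, 0.333, 0.333, 0.0, 0.0]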
|
57 | 75 |
|
58 | | -TinyBertModel.forward = tinybert_forward |
| 76 | +BertModel.forward = bert_forward |
59 | 77 |
|
60 | 78 |
|
61 | 79 | def parse_args(): |
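Reviewer note: because BertModel.forward is monkey-patched above, the first return value becomes a list with one tensor per encoder layer whenever output_hidden_states=True, which is what the OFA distillation setup consumes. A minimal usage sketch, assuming the patch has already been applied in this script; the checkpoint name is an assumption for illustration:

    import paddle
    from paddlenlp.transformers import BertModel, BertTokenizer

    # "bert-base-uncased" is used purely as an example checkpoint.
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")
    model.eval()

    encoded = tokenizer("OFA distillation needs every layer's output.")
    input_ids = paddle.to_tensor([encoded["input_ids"]])
    token_type_ids = paddle.to_tensor([encoded["token_type_ids"]])

    # First return value: list of per-layer hidden states; pooled comes from the last one.
    encoder_outputs, pooled = model(
        input_ids,
        token_type_ids=token_type_ids,
        output_hidden_states=True)
    print(len(encoder_outputs), pooled.shape)    # e.g. 12 layers, [1, 768]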
|