@@ -52,32 +52,81 @@ StatusOr<std::unique_ptr<BertPreprocessor>> BertPreprocessor::Create(
   return processor;
 }
 
+// TODO(b/241507692) Add a unit test for a model with dynamic tensors.
 absl::Status BertPreprocessor::Init() {
-  // Try if RegexTokenzier can be found.
-  // BertTokenzier is packed in the processing unit SubgraphMetadata.
+  // Try if RegexTokenizer can be found.
+  // BertTokenizer is packed in the processing unit SubgraphMetadata.
   const tflite::ProcessUnit* tokenizer_metadata =
       GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex);
   ASSIGN_OR_RETURN(tokenizer_, CreateTokenizerFromProcessUnit(
                                    tokenizer_metadata, GetMetadataExtractor()));
 
-  // Sanity check and assign max sequence length.
-  if (GetLastDimSize(tensor_indices_[kIdsTensorIndex]) !=
-          GetLastDimSize(tensor_indices_[kMaskTensorIndex]) ||
-      GetLastDimSize(tensor_indices_[kIdsTensorIndex]) !=
-          GetLastDimSize(tensor_indices_[kSegmentIdsTensorIndex])) {
+  const auto& ids_tensor = *GetTensor(kIdsTensorIndex);
+  const auto& mask_tensor = *GetTensor(kMaskTensorIndex);
+  const auto& segment_ids_tensor = *GetTensor(kSegmentIdsTensorIndex);
+  if (ids_tensor.dims->size != 2 || mask_tensor.dims->size != 2 ||
+      segment_ids_tensor.dims->size != 2) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInternal,
         absl::StrFormat(
-            "The three input tensors in Bert models are "
-            "expected to have same length, but got ids_tensor "
-            "(%d), mask_tensor (%d), segment_ids_tensor (%d).",
-            GetLastDimSize(tensor_indices_[kIdsTensorIndex]),
-            GetLastDimSize(tensor_indices_[kMaskTensorIndex]),
-            GetLastDimSize(tensor_indices_[kSegmentIdsTensorIndex])),
-        TfLiteSupportStatus::kInvalidNumOutputTensorsError);
+            "The three input tensors in Bert models are expected to have dim "
+            "2, but got ids_tensor (%d), mask_tensor (%d), segment_ids_tensor "
+            "(%d).",
+            ids_tensor.dims->size, mask_tensor.dims->size,
+            segment_ids_tensor.dims->size),
+        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
+  }
+  if (ids_tensor.dims->data[0] != 1 || mask_tensor.dims->data[0] != 1 ||
+      segment_ids_tensor.dims->data[0] != 1) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInternal,
+        absl::StrFormat(
+            "The three input tensors in Bert models are expected to have same "
+            "batch size 1, but got ids_tensor (%d), mask_tensor (%d), "
+            "segment_ids_tensor (%d).",
+            ids_tensor.dims->data[0], mask_tensor.dims->data[0],
+            segment_ids_tensor.dims->data[0]),
+        TfLiteSupportStatus::kInvalidInputTensorSizeError);
+  }
+  if (ids_tensor.dims->data[1] != mask_tensor.dims->data[1] ||
+      ids_tensor.dims->data[1] != segment_ids_tensor.dims->data[1]) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInternal,
+        absl::StrFormat("The three input tensors in Bert models are "
+                        "expected to have same length, but got ids_tensor "
+                        "(%d), mask_tensor (%d), segment_ids_tensor (%d).",
+                        ids_tensor.dims->data[1], mask_tensor.dims->data[1],
+                        segment_ids_tensor.dims->data[1]),
+        TfLiteSupportStatus::kInvalidInputTensorSizeError);
   }
-  bert_max_seq_len_ = GetLastDimSize(tensor_indices_[kIdsTensorIndex]);
 
+  bool has_valid_dims_signature = ids_tensor.dims_signature->size == 2 &&
+                                  mask_tensor.dims_signature->size == 2 &&
+                                  segment_ids_tensor.dims_signature->size == 2;
+  if (has_valid_dims_signature && ids_tensor.dims_signature->data[1] == -1 &&
+      mask_tensor.dims_signature->data[1] == -1 &&
+      segment_ids_tensor.dims_signature->data[1] == -1) {
+    input_tensors_are_dynamic_ = true;
+  } else if (has_valid_dims_signature &&
+             (ids_tensor.dims_signature->data[1] == -1 ||
+              mask_tensor.dims_signature->data[1] == -1 ||
+              segment_ids_tensor.dims_signature->data[1] == -1)) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInternal,
+        "Input tensors contain a mix of static and dynamic tensors",
+        TfLiteSupportStatus::kInvalidInputTensorSizeError);
+  }
+
+  if (input_tensors_are_dynamic_) return absl::OkStatus();
+
+  bert_max_seq_len_ = ids_tensor.dims->data[1];
+  if (bert_max_seq_len_ < 2) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInternal,
+        absl::StrFormat("bert_max_seq_len_ should be at least 2, got: (%d).",
+                        bert_max_seq_len_),
+        TfLiteSupportStatus::kInvalidInputTensorSizeError);
+  }
   return absl::OkStatus();
 }
 
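The new Init() logic keys the static/dynamic decision off TfLiteTensor::dims_signature rather than dims: dims holds the currently allocated shape, while dims_signature records -1 for a dimension that was exported as dynamic. As a minimal sketch (not part of this change, helper name hypothetical), the same classification of a [batch, seq_len] input tensor could look like:

#include "tensorflow/lite/c/common.h"

// Hypothetical helper: the sequence dimension is dynamic when its
// dims_signature entry is -1; dims only reflects the shape currently
// allocated by the interpreter.
bool SeqLenIsDynamic(const TfLiteTensor& tensor) {
  return tensor.dims_signature != nullptr &&
         tensor.dims_signature->size == 2 &&
         tensor.dims_signature->data[1] == -1;
}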
@@ -92,48 +141,50 @@ absl::Status BertPreprocessor::Preprocess(const std::string& input_text) {
   TokenizerResult input_tokenize_results;
   input_tokenize_results = tokenizer_->Tokenize(processed_input);
 
-  // 2 accounts for [CLS], [SEP]
-  absl::Span<const std::string> query_tokens =
-      absl::MakeSpan(input_tokenize_results.subwords.data(),
-                     input_tokenize_results.subwords.data() +
-                         std::min(static_cast<size_t>(bert_max_seq_len_ - 2),
-                                  input_tokenize_results.subwords.size()));
-
-  std::vector<std::string> tokens;
-  tokens.reserve(2 + query_tokens.size());
-  // Start of generating the features.
-  tokens.push_back(kClassificationToken);
-  // For query input.
-  for (const auto& query_token : query_tokens) {
-    tokens.push_back(query_token);
+  // Offset by 2 to account for [CLS] and [SEP]
+  int input_tokens_size =
+      static_cast<int>(input_tokenize_results.subwords.size()) + 2;
+  int input_tensor_length = input_tokens_size;
+  if (!input_tensors_are_dynamic_) {
+    input_tokens_size = std::min(bert_max_seq_len_, input_tokens_size);
+    input_tensor_length = bert_max_seq_len_;
+  } else {
+    engine_->interpreter()->ResizeInputTensorStrict(kIdsTensorIndex,
+                                                    {1, input_tensor_length});
+    engine_->interpreter()->ResizeInputTensorStrict(kMaskTensorIndex,
+                                                    {1, input_tensor_length});
+    engine_->interpreter()->ResizeInputTensorStrict(kSegmentIdsTensorIndex,
+                                                    {1, input_tensor_length});
+    engine_->interpreter()->AllocateTensors();
   }
-  // For Separation.
-  tokens.push_back(kSeparator);
 
-  std::vector<int> input_ids(bert_max_seq_len_, 0);
-  std::vector<int> input_mask(bert_max_seq_len_, 0);
+  std::vector<std::string> input_tokens;
+  input_tokens.reserve(input_tokens_size);
+  input_tokens.push_back(std::string(kClassificationToken));
+  for (int i = 0; i < input_tokens_size - 2; ++i) {
+    input_tokens.push_back(std::move(input_tokenize_results.subwords[i]));
+  }
+  input_tokens.push_back(std::string(kSeparator));
+
+  std::vector<int> input_ids(input_tensor_length, 0);
+  std::vector<int> input_mask(input_tensor_length, 0);
   // Convert tokens back into ids and set mask
-  for (int i = 0; i < tokens.size(); ++i) {
-    tokenizer_->LookupId(tokens[i], &input_ids[i]);
+  for (int i = 0; i < input_tokens.size(); ++i) {
+    tokenizer_->LookupId(input_tokens[i], &input_ids[i]);
     input_mask[i] = 1;
   }
-  // |<--------bert_max_seq_len_--------->|
+  // |<--------input_tensor_length------->|
   // input_ids    [CLS] s1  s2...  sn [SEP] 0  0...  0
   // input_masks   1    1   1...   1    1   0  0...  0
   // segment_ids   0    0   0...   0    0   0  0...  0
 
   RETURN_IF_ERROR(PopulateTensor(input_ids, ids_tensor));
   RETURN_IF_ERROR(PopulateTensor(input_mask, mask_tensor));
-  RETURN_IF_ERROR(PopulateTensor(std::vector<int>(bert_max_seq_len_, 0),
+  RETURN_IF_ERROR(PopulateTensor(std::vector<int>(input_tensor_length, 0),
                                  segment_ids_tensor));
   return absl::OkStatus();
 }
 
-int BertPreprocessor::GetLastDimSize(int tensor_index) {
-  auto tensor = engine_->GetInput(engine_->interpreter(), tensor_index);
-  return tensor->dims->data[tensor->dims->size - 1];
-}
-
 }  // namespace processor
 }  // namespace task
 }  // namespace tflite
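For the dynamic path in Preprocess(), the interpreter inputs have to be resized to the actual token count and re-allocated before PopulateTensor writes into them. A minimal usage sketch of that flow, assuming a built tflite::Interpreter and hypothetical input tensor indices (unlike the change above, the TfLiteStatus return values are checked here for illustration):

#include <vector>

#include "absl/status/status.h"
#include "tensorflow/lite/interpreter.h"

// Sketch only: resize the three BERT inputs to [1, seq_len], then
// re-allocate so the tensor buffers match the new shapes.
absl::Status ResizeBertInputs(tflite::Interpreter* interpreter, int ids_index,
                              int mask_index, int segment_ids_index,
                              int seq_len) {
  for (int index : {ids_index, mask_index, segment_ids_index}) {
    if (interpreter->ResizeInputTensorStrict(index, {1, seq_len}) !=
        kTfLiteOk) {
      return absl::InternalError("ResizeInputTensorStrict failed");
    }
  }
  if (interpreter->AllocateTensors() != kTfLiteOk) {
    return absl::InternalError("AllocateTensors failed");
  }
  return absl::OkStatus();
}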