Skip to content

Commit 527be55

Browse files
Internal change
PiperOrigin-RevId: 467457993
1 parent dd36717 commit 527be55

File tree

2 files changed

+98
-45
lines changed

2 files changed

+98
-45
lines changed

tensorflow_lite_support/cc/task/processor/bert_preprocessor.cc

Lines changed: 93 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -52,32 +52,81 @@ StatusOr<std::unique_ptr<BertPreprocessor>> BertPreprocessor::Create(
5252
return processor;
5353
}
5454

55+
// TODO(b/241507692) Add a unit test for a model with dynamic tensors.
5556
absl::Status BertPreprocessor::Init() {
56-
// Try if RegexTokenzier can be found.
57-
// BertTokenzier is packed in the processing unit SubgraphMetadata.
57+
// Try if RegexTokenizer can be found.
58+
// BertTokenizer is packed in the processing unit SubgraphMetadata.
5859
const tflite::ProcessUnit* tokenizer_metadata =
5960
GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex);
6061
ASSIGN_OR_RETURN(tokenizer_, CreateTokenizerFromProcessUnit(
6162
tokenizer_metadata, GetMetadataExtractor()));
6263

63-
// Sanity check and assign max sequence length.
64-
if (GetLastDimSize(tensor_indices_[kIdsTensorIndex]) !=
65-
GetLastDimSize(tensor_indices_[kMaskTensorIndex]) ||
66-
GetLastDimSize(tensor_indices_[kIdsTensorIndex]) !=
67-
GetLastDimSize(tensor_indices_[kSegmentIdsTensorIndex])) {
64+
const auto& ids_tensor = *GetTensor(kIdsTensorIndex);
65+
const auto& mask_tensor = *GetTensor(kMaskTensorIndex);
66+
const auto& segment_ids_tensor = *GetTensor(kSegmentIdsTensorIndex);
67+
if (ids_tensor.dims->size != 2 || mask_tensor.dims->size != 2 ||
68+
segment_ids_tensor.dims->size != 2) {
6869
return CreateStatusWithPayload(
6970
absl::StatusCode::kInternal,
7071
absl::StrFormat(
71-
"The three input tensors in Bert models are "
72-
"expected to have same length, but got ids_tensor "
73-
"(%d), mask_tensor (%d), segment_ids_tensor (%d).",
74-
GetLastDimSize(tensor_indices_[kIdsTensorIndex]),
75-
GetLastDimSize(tensor_indices_[kMaskTensorIndex]),
76-
GetLastDimSize(tensor_indices_[kSegmentIdsTensorIndex])),
77-
TfLiteSupportStatus::kInvalidNumOutputTensorsError);
72+
"The three input tensors in Bert models are expected to have dim "
73+
"2, but got ids_tensor (%d), mask_tensor (%d), segment_ids_tensor "
74+
"(%d).",
75+
ids_tensor.dims->size, mask_tensor.dims->size,
76+
segment_ids_tensor.dims->size),
77+
TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
78+
}
79+
if (ids_tensor.dims->data[0] != 1 || mask_tensor.dims->data[0] != 1 ||
80+
segment_ids_tensor.dims->data[0] != 1) {
81+
return CreateStatusWithPayload(
82+
absl::StatusCode::kInternal,
83+
absl::StrFormat(
84+
"The three input tensors in Bert models are expected to have same "
85+
"batch size 1, but got ids_tensor (%d), mask_tensor (%d), "
86+
"segment_ids_tensor (%d).",
87+
ids_tensor.dims->data[0], mask_tensor.dims->data[0],
88+
segment_ids_tensor.dims->data[0]),
89+
TfLiteSupportStatus::kInvalidInputTensorSizeError);
90+
}
91+
if (ids_tensor.dims->data[1] != mask_tensor.dims->data[1] ||
92+
ids_tensor.dims->data[1] != segment_ids_tensor.dims->data[1]) {
93+
return CreateStatusWithPayload(
94+
absl::StatusCode::kInternal,
95+
absl::StrFormat("The three input tensors in Bert models are "
96+
"expected to have same length, but got ids_tensor "
97+
"(%d), mask_tensor (%d), segment_ids_tensor (%d).",
98+
ids_tensor.dims->data[1], mask_tensor.dims->data[1],
99+
segment_ids_tensor.dims->data[1]),
100+
TfLiteSupportStatus::kInvalidInputTensorSizeError);
78101
}
79-
bert_max_seq_len_ = GetLastDimSize(tensor_indices_[kIdsTensorIndex]);
80102

103+
bool has_valid_dims_signature = ids_tensor.dims_signature->size == 2 &&
104+
mask_tensor.dims_signature->size == 2 &&
105+
segment_ids_tensor.dims_signature->size == 2;
106+
if (has_valid_dims_signature && ids_tensor.dims_signature->data[1] == -1 &&
107+
mask_tensor.dims_signature->data[1] == -1 &&
108+
segment_ids_tensor.dims_signature->data[1] == -1) {
109+
input_tensors_are_dynamic_ = true;
110+
} else if (has_valid_dims_signature &&
111+
(ids_tensor.dims_signature->data[1] == -1 ||
112+
mask_tensor.dims_signature->data[1] == -1 ||
113+
segment_ids_tensor.dims_signature->data[1] == -1)) {
114+
return CreateStatusWithPayload(
115+
absl::StatusCode::kInternal,
116+
"Input tensors contain a mix of static and dynamic tensors",
117+
TfLiteSupportStatus::kInvalidInputTensorSizeError);
118+
}
119+
120+
if (input_tensors_are_dynamic_) return absl::OkStatus();
121+
122+
bert_max_seq_len_ = ids_tensor.dims->data[1];
123+
if (bert_max_seq_len_ < 2) {
124+
return CreateStatusWithPayload(
125+
absl::StatusCode::kInternal,
126+
absl::StrFormat("bert_max_seq_len_ should be at least 2, got: (%d).",
127+
bert_max_seq_len_),
128+
TfLiteSupportStatus::kInvalidInputTensorSizeError);
129+
}
81130
return absl::OkStatus();
82131
}
83132

@@ -92,48 +141,50 @@ absl::Status BertPreprocessor::Preprocess(const std::string& input_text) {
92141
TokenizerResult input_tokenize_results;
93142
input_tokenize_results = tokenizer_->Tokenize(processed_input);
94143

95-
// 2 accounts for [CLS], [SEP]
96-
absl::Span<const std::string> query_tokens =
97-
absl::MakeSpan(input_tokenize_results.subwords.data(),
98-
input_tokenize_results.subwords.data() +
99-
std::min(static_cast<size_t>(bert_max_seq_len_ - 2),
100-
input_tokenize_results.subwords.size()));
101-
102-
std::vector<std::string> tokens;
103-
tokens.reserve(2 + query_tokens.size());
104-
// Start of generating the features.
105-
tokens.push_back(kClassificationToken);
106-
// For query input.
107-
for (const auto& query_token : query_tokens) {
108-
tokens.push_back(query_token);
144+
// Offset by 2 to account for [CLS] and [SEP]
145+
int input_tokens_size =
146+
static_cast<int>(input_tokenize_results.subwords.size()) + 2;
147+
int input_tensor_length = input_tokens_size;
148+
if (!input_tensors_are_dynamic_) {
149+
input_tokens_size = std::min(bert_max_seq_len_, input_tokens_size);
150+
input_tensor_length = bert_max_seq_len_;
151+
} else {
152+
engine_->interpreter()->ResizeInputTensorStrict(kIdsTensorIndex,
153+
{1, input_tensor_length});
154+
engine_->interpreter()->ResizeInputTensorStrict(kMaskTensorIndex,
155+
{1, input_tensor_length});
156+
engine_->interpreter()->ResizeInputTensorStrict(kSegmentIdsTensorIndex,
157+
{1, input_tensor_length});
158+
engine_->interpreter()->AllocateTensors();
109159
}
110-
// For Separation.
111-
tokens.push_back(kSeparator);
112160

113-
std::vector<int> input_ids(bert_max_seq_len_, 0);
114-
std::vector<int> input_mask(bert_max_seq_len_, 0);
161+
std::vector<std::string> input_tokens;
162+
input_tokens.reserve(input_tokens_size);
163+
input_tokens.push_back(std::string(kClassificationToken));
164+
for (int i = 0; i < input_tokens_size - 2; ++i) {
165+
input_tokens.push_back(std::move(input_tokenize_results.subwords[i]));
166+
}
167+
input_tokens.push_back(std::string(kSeparator));
168+
169+
std::vector<int> input_ids(input_tensor_length, 0);
170+
std::vector<int> input_mask(input_tensor_length, 0);
115171
// Convert tokens back into ids and set mask
116-
for (int i = 0; i < tokens.size(); ++i) {
117-
tokenizer_->LookupId(tokens[i], &input_ids[i]);
172+
for (int i = 0; i < input_tokens.size(); ++i) {
173+
tokenizer_->LookupId(input_tokens[i], &input_ids[i]);
118174
input_mask[i] = 1;
119175
}
120-
// |<--------bert_max_seq_len_--------->|
176+
// |<--------input_tensor_length------->|
121177
// input_ids [CLS] s1 s2... sn [SEP] 0 0... 0
122178
// input_masks 1 1 1... 1 1 0 0... 0
123179
// segment_ids 0 0 0... 0 0 0 0... 0
124180

125181
RETURN_IF_ERROR(PopulateTensor(input_ids, ids_tensor));
126182
RETURN_IF_ERROR(PopulateTensor(input_mask, mask_tensor));
127-
RETURN_IF_ERROR(PopulateTensor(std::vector<int>(bert_max_seq_len_, 0),
183+
RETURN_IF_ERROR(PopulateTensor(std::vector<int>(input_tensor_length, 0),
128184
segment_ids_tensor));
129185
return absl::OkStatus();
130186
}
131187

132-
int BertPreprocessor::GetLastDimSize(int tensor_index) {
133-
auto tensor = engine_->GetInput(engine_->interpreter(), tensor_index);
134-
return tensor->dims->data[tensor->dims->size - 1];
135-
}
136-
137188
} // namespace processor
138189
} // namespace task
139190
} // namespace tflite

tensorflow_lite_support/cc/task/processor/bert_preprocessor.h

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -46,10 +46,12 @@ class BertPreprocessor : public TextPreprocessor {
4646

4747
absl::Status Init();
4848

49-
int GetLastDimSize(int tensor_index);
50-
5149
std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_;
52-
int bert_max_seq_len_;
50+
// The maximum input sequence length the BERT model can accept. Used for
51+
// static input tensors.
52+
int bert_max_seq_len_ = 2;
53+
// Whether the input tensors are dynamic instead of static.
54+
bool input_tensors_are_dynamic_ = false;
5355
};
5456

5557
} // namespace processor

0 commit comments

Comments (0)