Add INT32 support to SUB #3037

Open
wants to merge 5 commits into base: main
Changes from 4 commits
94 changes: 66 additions & 28 deletions tensorflow/lite/micro/kernels/sub.cc
@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -36,39 +36,76 @@ void* SubInit(TfLiteContext* context, const char* buffer, size_t length) {
return context->AllocatePersistentBuffer(context, sizeof(OpDataSub));
}

void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
const OpDataSub* data, const TfLiteEvalTensor* input1,
const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
tflite::ArithmeticParams op_params;
SetActivationParams(output_activation_min, output_activation_max, &op_params);
if (data->requires_broadcast) {
tflite::reference_ops::BroadcastSubSlow(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<float>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<float>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
} else {
tflite::reference_ops::SubWithActivation(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<float>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<float>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
TfLiteStatus EvalSub(TfLiteContext* context, TfLiteNode* node,
TfLiteSubParams* params, const OpDataSub* data,
const TfLiteEvalTensor* input1,
const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
switch (output->type) {
case kTfLiteFloat32: {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
tflite::ArithmeticParams op_params;
SetActivationParams(output_activation_min, output_activation_max,
&op_params);
if (data->requires_broadcast) {
tflite::reference_ops::BroadcastSubSlow(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<float>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<float>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
} else {
tflite::reference_ops::SubWithActivation(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<float>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<float>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
}
} break;
case kTfLiteInt32: {
int32_t output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
tflite::ArithmeticParams op_params;
SetActivationParams(output_activation_min, output_activation_max,
&op_params);
if (data->requires_broadcast) {
tflite::reference_ops::BroadcastSubSlow(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<int32_t>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<int32_t>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int32_t>(output));
} else {
tflite::reference_ops::SubWithActivation(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<int32_t>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<int32_t>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int32_t>(output));
}
} break;
default:
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(output->type), output->type);
return kTfLiteError;
}

return kTfLiteOk;
}

TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteSubParams* params, const OpDataSub* data,
const TfLiteEvalTensor* input1,
const TfLiteEvalTensor* input2,
TfLiteEvalTensor* output) {
tflite::ArithmeticParams op_params;
tflite::ArithmeticParams op_params = {};
op_params.left_shift = data->left_shift;
op_params.input1_offset = data->input1_offset;
op_params.input1_multiplier = data->input1_multiplier;
@@ -147,8 +184,9 @@ TfLiteStatus SubEval(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
const OpDataSub& data = *(static_cast<const OpDataSub*>(node->user_data));

if (output->type == kTfLiteFloat32) {
EvalSub(context, node, params, &data, input1, input2, output);
if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
TF_LITE_ENSURE_OK(
context, EvalSub(context, node, params, &data, input1, input2, output));
} else if (output->type == kTfLiteInt8 || output->type == kTfLiteInt16) {
TF_LITE_ENSURE_OK(context, EvalSubQuantized(context, node, params, &data,
input1, input2, output));
12 changes: 11 additions & 1 deletion tensorflow/lite/micro/kernels/sub_common.cc
@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -98,6 +98,16 @@ TfLiteStatus SubPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_STATUS(
CalculateOpDataSub(context, params, input1, input2, output, data));

if (output->type == kTfLiteInt32) {
// Only support INT32 unquantized SUB for now.
TF_LITE_ENSURE_EQ(context, input1->quantization.type,
kTfLiteNoQuantization);
TF_LITE_ENSURE_EQ(context, input2->quantization.type,
kTfLiteNoQuantization);
TF_LITE_ENSURE_EQ(context, output->quantization.type,
kTfLiteNoQuantization);
}

micro_context->DeallocateTempTfLiteTensor(input1);
micro_context->DeallocateTempTfLiteTensor(input2);
micro_context->DeallocateTempTfLiteTensor(output);
37 changes: 36 additions & 1 deletion tensorflow/lite/micro/kernels/sub_test.cc
@@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -105,6 +105,27 @@ void TestSubFloat(int* input1_dims_data, const float* input1_data,
ElementCount(*output_dims), activation);
}

void TestSubInt32(int* input1_dims_data, const int32_t* input1_data,
int* input2_dims_data, const int32_t* input2_data,

Review comment (Member): You will need #if !defined(XTENSA) around this method also to prevent the unused function warning (which is promoted to an error).

Reply (Author): done @ddavis-2015.

int* output_dims_data, const int32_t* expected_output,
TfLiteFusedActivation activation, int32_t* output_data) {
TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data);
TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);

constexpr int inputs_size = 2;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
TfLiteTensor tensors[tensors_size] = {
CreateTensor(input1_data, input1_dims),
CreateTensor(input2_data, input2_dims),
CreateTensor(output_data, output_dims),
};

ValidateSubGoldens(tensors, tensors_size, expected_output, output_data,
ElementCount(*output_dims), activation);
}
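
// Sketch (assumption, not part of the 4-commit diff shown above): per the
// review exchange, the follow-up marked "done" would bracket this helper
// with the same guard used for the Int32SubNoActivation test further below,
// so Xtensa builds that compile the test out do not hit the unused-function
// warning (which is promoted to an error). The body is elided here; it is
// identical to the helper defined above.
#if !defined(XTENSA)
void TestSubInt32(int* input1_dims_data, const int32_t* input1_data,
                  int* input2_dims_data, const int32_t* input2_data,
                  int* output_dims_data, const int32_t* expected_output,
                  TfLiteFusedActivation activation, int32_t* output_data) {
  // ... same body as the definition above ...
}
#endif  // !defined(XTENSA)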

template <typename T>
void TestSubQuantized(int* input1_dims_data, const float* input1_data,
T* input1_quantized, float input1_scale,
@@ -219,6 +240,20 @@ TF_LITE_MICRO_TEST(FloatSubWithScalarBroadcast) {
}
}

#if !defined(XTENSA)
TF_LITE_MICRO_TEST(Int32SubNoActivation) {
int inout_shape[] = {4, 1, 2, 2, 1};
const int32_t input1_values[] = {-2, 2147483646, -1, 1146622854};
const int32_t input2_values[] = {3, 1, -2147483647, -726978367};
const int32_t golden_values[] = {-5, 2147483645, 2147483646, 1873601221};
const int kOutputDimsCount = 4;
int32_t output_data[kOutputDimsCount];
tflite::testing::TestSubInt32(inout_shape, input1_values, inout_shape,
input2_values, inout_shape, golden_values,
kTfLiteActNone, output_data);
}
#endif

TF_LITE_MICRO_TEST(QuantizedSubNoActivationInt8) {
const float scales[] = {0.25, 0.5, 1.0};
const int zero_points[] = {-10, 4, 13};