Skip to content

Commit 91c5436

Browse files
khanhlvgcopybara-github
authored andcommitted
Added TFLite Raspberry Pi (Python) image segmentation sample app.
PiperOrigin-RevId: 414929525
1 parent 040b327 commit 91c5436

7 files changed

Lines changed: 364 additions & 0 deletions

File tree

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# TensorFlow Lite Python image segmentation example with Raspberry Pi.
2+
3+
This example uses [TensorFlow Lite](https://tensorflow.org/lite) with Python on
4+
a Raspberry Pi to perform real-time image segmentation using images streamed
5+
from the camera.
6+
7+
At the end of this page, there are extra steps to accelerate the example using
8+
the Coral USB Accelerator, which increases the inference speed by ~10x.
9+
10+
## Set up your hardware
11+
12+
Before you begin, you need to
13+
[set up your Raspberry Pi](https://projects.raspberrypi.org/en/projects/raspberry-pi-setting-up)
14+
with Raspberry Pi OS (preferably updated to Buster).
15+
16+
You also need to
17+
[connect and configure the Pi Camera](https://www.raspberrypi.org/documentation/configuration/camera.md)
18+
if you use the Pi Camera. This code also works with USB camera connect to the
19+
Raspberry Pi.
20+
21+
And to see the results from the camera, you need a monitor connected to the
22+
Raspberry Pi. It's okay if you're using SSH to access the Pi shell (you don't
23+
need to use a keyboard connected to the Pi)—you only need a monitor attached to
24+
the Pi to see the camera stream.
25+
26+
## Install the TensorFlow Lite runtime
27+
28+
In this project, all you need from the TensorFlow Lite API is the `Interpreter`
29+
class. So instead of installing the large `tensorflow` package, we're using the
30+
much smaller `tflite_runtime` package.
31+
32+
To install this on your Raspberry Pi, follow the instructions in the
33+
[Python quickstart](https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python).
34+
35+
You can install the TFLite runtime using this script.
36+
37+
```
38+
sh setup.sh
39+
```
40+
41+
## Download the example files
42+
43+
First, clone this Git repo onto your Raspberry Pi like this:
44+
45+
```
46+
git clone https://github.com/tensorflow/examples --depth 1
47+
```
48+
49+
Then use our script to install a couple Python packages, and download the
50+
`Deeplabv3` model:
51+
52+
```
53+
cd examples/lite/examples/image_segmentation/raspberry_pi
54+
55+
# The script install the required dependencies and download the TFLite models.
56+
sh setup.sh
57+
```
58+
59+
## Run the example
60+
61+
```
62+
python3 segment.py
63+
```
64+
65+
* You can optionally specify the `model` parameter to set the TensorFlow Lite
66+
model to be used:
67+
* The default value is `deeplabv3.tflite`
68+
* Image segmentation models from TensorFlow Hub **with metadata** are
69+
supported.
70+
* You can optionally specify the `displayMode` parameter to change how the
71+
segmentation result is displayed:
72+
* Use values: `overlay`, `side-by-side`.
73+
* The default value is `overlay`.
74+
* Example usage:
75+
76+
```
77+
python3 main.py
78+
--model somemodel.tflite
79+
--displayMode side-by-side
80+
```
81+
82+
**Overlay mode** ![Overlay Image](overlay_mode.png)
83+
84+
**Side-by-side mode** ![Side-by-side Image](sidebyside_mode.png)
85+
86+
For more information about executing inferences with TensorFlow Lite, read
87+
[TensorFlow Lite inference](https://www.tensorflow.org/lite/guide/inference).
88+
89+
## Speed up model inference (optional)
90+
91+
If you want to significantly speed up the inference time, you can attach an
92+
[Coral USB Accelerator](https://coral.withgoogle.com/products/accelerator)—a USB
93+
accessory that adds the
94+
[Edge TPU ML accelerator](https://coral.withgoogle.com/docs/edgetpu/faq/) to any
95+
Linux-based system.
96+
97+
If you have a Coral USB Accelerator, you can run the sample with it enabled:
98+
99+
1. First, be sure you have completed the
100+
[USB Accelerator setup instructions](https://coral.withgoogle.com/docs/accelerator/get-started/).
101+
102+
2. Run the image segmentation script using the EdgeTPU TFLite model and enable
103+
the EdgeTPU option.
104+
105+
```
106+
python3 main.py \
107+
--enableEdgeTPU
108+
--model deeplabv3_edgetpu.tflite
109+
```
110+
111+
You should see significantly faster inference speeds.
112+
113+
For more information about creating and running TensorFlow Lite models with
114+
Coral devices, read
115+
[TensorFlow models on the Edge TPU](https://coral.withgoogle.com/docs/edgetpu/models-intro/).
305 KB
Loading
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
argparse
2+
numpy>=1.20.0
3+
opencv-python~=4.5.3.56
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
--extra-index-url https://google-coral.github.io/py-repo/
2+
tflite-runtime==2.5.0.post1
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Main script to run image segmentation."""
15+
16+
import argparse
17+
import sys
18+
import time
19+
from typing import List
20+
21+
import cv2
22+
from image_segmenter import ColoredLabel
23+
from image_segmenter import ImageSegmenter
24+
from image_segmenter import ImageSegmenterOptions
25+
import numpy as np
26+
import utils
27+
28+
# Visualization parameters
29+
_FPS_AVERAGE_FRAME_COUNT = 10
30+
_FPS_LEFT_MARGIN = 24 # pixels
31+
_LEGEND_TEXT_COLOR = (0, 0, 255) # red
32+
_LEGEND_BACKGROUND_COLOR = (255, 255, 255) # white
33+
_LEGEND_FONT_SIZE = 1
34+
_LEGEND_FONT_THICKNESS = 1
35+
_LEGEND_ROW_SIZE = 20 # pixels
36+
_LEGEND_RECT_SIZE = 16 # pixels
37+
_LABEL_MARGIN = 10
38+
_OVERLAY_ALPHA = 0.5
39+
_PADDING_WIDTH_FOR_LEGEND = 150 # pixels
40+
41+
42+
def run(model: str, display_mode: str, num_threads: int, enable_edgetpu: bool,
43+
camera_id: int, width: int, height: int) -> None:
44+
"""Continuously run inference on images acquired from the camera.
45+
46+
Args:
47+
model: Name of the TFLite image segmentation model.
48+
display_mode: Name of mode to display image segmentation.
49+
num_threads: Number of CPU threads to run the model.
50+
enable_edgetpu: Whether to run the model on EdgeTPU.
51+
camera_id: The camera id to be passed to OpenCV.
52+
width: The width of the frame captured from the camera.
53+
height: The height of the frame captured from the camera.
54+
"""
55+
56+
# Initialize the image segmentation model.
57+
options = ImageSegmenterOptions(
58+
num_threads=num_threads, enable_edgetpu=enable_edgetpu)
59+
segmenter = ImageSegmenter(model_path=model, options=options)
60+
61+
# Variables to calculate FPS
62+
counter, fps = 0, 0
63+
start_time = time.time()
64+
65+
# Start capturing video input from the camera
66+
cap = cv2.VideoCapture(camera_id)
67+
cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
68+
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
69+
70+
# Continuously capture images from the camera and run inference.
71+
while cap.isOpened():
72+
success, image = cap.read()
73+
if not success:
74+
sys.exit(
75+
'ERROR: Unable to read from webcam. Please verify your webcam settings.'
76+
)
77+
78+
counter += 1
79+
image = cv2.flip(image, 1)
80+
81+
# Segment with each frame from camera.
82+
segmentation_result = segmenter.segment(image)
83+
84+
# Convert the segmentation result into an image.
85+
seg_map_img, found_colored_labels = utils.segmentation_map_to_image(
86+
segmentation_result)
87+
88+
# Resize the segmentation mask to be the same shape as input image.
89+
seg_map_img = cv2.resize(
90+
seg_map_img,
91+
dsize=(image.shape[1], image.shape[0]),
92+
interpolation=cv2.INTER_NEAREST)
93+
94+
# Visualize segmentation result on image.
95+
overlay = visualize(image, seg_map_img, display_mode, fps,
96+
found_colored_labels)
97+
98+
# Calculate the FPS
99+
if counter % _FPS_AVERAGE_FRAME_COUNT == 0:
100+
end_time = time.time()
101+
fps = _FPS_AVERAGE_FRAME_COUNT / (end_time - start_time)
102+
start_time = time.time()
103+
104+
# Stop the program if the ESC key is pressed.
105+
if cv2.waitKey(1) == 27:
106+
break
107+
cv2.imshow('image_segmentation', overlay)
108+
109+
cap.release()
110+
cv2.destroyAllWindows()
111+
112+
113+
def visualize(input_image: np.ndarray, segmentation_map_image: np.ndarray,
114+
display_mode: str, fps: float,
115+
colored_labels: List[ColoredLabel]) -> np.ndarray:
116+
"""Visualize segmentation result on image.
117+
118+
Args:
119+
input_image: The [height, width, 3] RGB input image.
120+
segmentation_map_image: The [height, width, 3] RGB segmentation map image.
121+
display_mode: How the segmentation map should be shown. 'overlay' or
122+
'side-by-side'.
123+
fps: Value of fps.
124+
colored_labels: List of colored labels found in the segmentation result.
125+
126+
Returns:
127+
Input image overlaid with segmentation result.
128+
"""
129+
# Show the input image and the segmentation map image.
130+
if display_mode == 'overlay':
131+
# Overlay mode.
132+
overlay = cv2.addWeighted(input_image, _OVERLAY_ALPHA,
133+
segmentation_map_image, _OVERLAY_ALPHA, 0)
134+
elif display_mode == 'side-by-side':
135+
# Side by side mode.
136+
overlay = cv2.hconcat([input_image, segmentation_map_image])
137+
else:
138+
sys.exit(f'ERROR: Unsupported display mode: {display_mode}.')
139+
140+
# Show the FPS
141+
fps_text = 'FPS = ' + str(int(fps))
142+
text_location = (_FPS_LEFT_MARGIN, _LEGEND_ROW_SIZE)
143+
cv2.putText(overlay, fps_text, text_location, cv2.FONT_HERSHEY_PLAIN,
144+
_LEGEND_FONT_SIZE, _LEGEND_TEXT_COLOR, _LEGEND_FONT_THICKNESS)
145+
146+
# Initialize the origin coordinates of the label.
147+
legend_x = overlay.shape[1] + _LABEL_MARGIN
148+
legend_y = overlay.shape[0] // _LEGEND_ROW_SIZE + _LABEL_MARGIN
149+
150+
# Expand the frame to show the label.
151+
overlay = cv2.copyMakeBorder(overlay, 0, 0, 0, _PADDING_WIDTH_FOR_LEGEND,
152+
cv2.BORDER_CONSTANT, None,
153+
_LEGEND_BACKGROUND_COLOR)
154+
155+
# Show the label on right-side frame.
156+
for colored_label in colored_labels:
157+
rect_color = colored_label.color
158+
start_point = (legend_x, legend_y)
159+
end_point = (legend_x + _LEGEND_RECT_SIZE, legend_y + _LEGEND_RECT_SIZE)
160+
cv2.rectangle(overlay, start_point, end_point, rect_color,
161+
-_LEGEND_FONT_THICKNESS)
162+
163+
label_location = legend_x + _LEGEND_RECT_SIZE + _LABEL_MARGIN, legend_y + _LABEL_MARGIN
164+
cv2.putText(overlay, colored_label.label, label_location,
165+
cv2.FONT_HERSHEY_PLAIN, _LEGEND_FONT_SIZE, _LEGEND_TEXT_COLOR,
166+
_LEGEND_FONT_THICKNESS)
167+
legend_y += (_LEGEND_RECT_SIZE + _LABEL_MARGIN)
168+
169+
return overlay
170+
171+
172+
def main():
173+
parser = argparse.ArgumentParser(
174+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
175+
parser.add_argument(
176+
'--model',
177+
help='Name of image segmentation model.',
178+
required=False,
179+
default='deeplabv3.tflite')
180+
parser.add_argument(
181+
'--displayMode',
182+
help='Mode to display image segmentation.',
183+
required=False,
184+
default='overlay')
185+
parser.add_argument(
186+
'--numThreads',
187+
help='Number of CPU threads to run the model.',
188+
required=False,
189+
default=4)
190+
parser.add_argument(
191+
'--enableEdgeTPU',
192+
help='Whether to run the model on EdgeTPU.',
193+
action='store_true',
194+
required=False,
195+
default=False)
196+
parser.add_argument(
197+
'--cameraId', help='Id of camera.', required=False, default=0)
198+
parser.add_argument(
199+
'--frameWidth',
200+
help='Width of frame to capture from camera.',
201+
required=False,
202+
default=640)
203+
parser.add_argument(
204+
'--frameHeight',
205+
help='Height of frame to capture from camera.',
206+
required=False,
207+
default=480)
208+
args = parser.parse_args()
209+
210+
run(args.model, args.displayMode, int(args.numThreads),
211+
bool(args.enableEdgeTPU), int(args.cameraId), args.frameWidth,
212+
args.frameHeight)
213+
214+
215+
if __name__ == '__main__':
216+
main()
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#!/bin/bash
2+
3+
if [ $# -eq 0 ]; then
4+
DATA_DIR="./"
5+
else
6+
DATA_DIR="$1"
7+
fi
8+
9+
# Install Python dependencies
10+
python3 -m pip install -r requirements_pypi.txt
11+
python3 -m pip install -r requirements_tflite.txt
12+
13+
# Download TF Lite models with metadata.
14+
FILE=${DATA_DIR}/deeplabv3.tflite
15+
if [ ! -f "$FILE" ]; then
16+
curl \
17+
-L 'https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2?lite-format=tflite' \
18+
-o ${FILE}
19+
fi
20+
21+
FILE=${DATA_DIR}/deeplabv3_edgetpu.tflite
22+
if [ ! -f "$FILE" ]; then
23+
curl \
24+
-L 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/edgetpu/deeplabv3_mnv2_dm05_pascal_quant_edgetpu.tflite' \
25+
-o ${FILE}
26+
fi
27+
28+
echo -e "Downloaded files are in ${DATA_DIR}"
353 KB
Loading

0 commit comments

Comments
 (0)