This project uses the TensorFlow Object Detection API to train models suitable for the Google Coral Edge TPU. Follow the steps below to install the required programs and to train your own models for use on the Edge TPU.
Follow these installation steps.
$ git clone
Check the requirements.txt file to ensure you have the necessary Python packages installed on your system or in your virtual environment.
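To check this quickly, here is a minimal sketch that reads requirements.txt and reports which listed packages have no installed distribution. It is not part of this project; the helper names `requirement_name` and `missing_requirements` are hypothetical. It assumes Python 3.8+ (for `importlib.metadata`) and simple `name==version` style lines, and it only checks presence, not version numbers.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path


def requirement_name(line):
    """Return the bare distribution name from one requirements.txt line, or None."""
    line = line.split("#", 1)[0].strip()  # drop comments and surrounding whitespace
    if not line or line.startswith("-"):  # skip blank lines and pip options such as -r
        return None
    # The name ends at the first version specifier, extras bracket, marker or space.
    match = re.match(r"[A-Za-z0-9][A-Za-z0-9._-]*", line)
    return match.group(0) if match else None


def missing_requirements(path="requirements.txt"):
    """Return requirement names with no installed distribution (versions are not checked)."""
    missing = []
    for line in Path(path).read_text().splitlines():
        name = requirement_name(line)
        if name is None:
            continue
        try:
            version(name)  # raises PackageNotFoundError if the distribution is absent
        except PackageNotFoundError:
            missing.append(name)
    return missing
```

Run `print(missing_requirements())` from the edge-tpu-train directory; an empty list means every listed package was found. Lines such as `git+https://...` URLs are outside the assumed format and would be misreported.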
Under the annotations, images, tf-record and tflite-models directories in edge-tpu-train, place a sub-directory named after your data set(s).
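The directory layout described above can be created with a short script. This is a sketch, not a script shipped with the project: `make_dataset_dirs` is a hypothetical helper, `repo_root` is assumed to point at your clone of edge-tpu-train, and the four directory names are copied from the step above.

```python
from pathlib import Path

# Top-level data directories named in the step above; each gets one sub-directory per data set.
DATA_DIRS = ("annotations", "images", "tf-record", "tflite-models")


def make_dataset_dirs(repo_root, dataset_name):
    """Create <repo_root>/<dir>/<dataset_name> for every data directory and return the paths."""
    created = []
    for top in DATA_DIRS:
        path = Path(repo_root) / top / dataset_name
        path.mkdir(parents=True, exist_ok=True)  # exist_ok makes the script safe to re-run
        created.append(path)
    return created
```

For example, `make_dataset_dirs("edge-tpu-train", "family")` creates `images/family`, `annotations/family` and so on, where `family` is just an illustrative data set name.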
Follow these steps to install labelImg, a graphical image annotation tool you can use to label the images you'll use for training.
Place your data set images in the images/<named-data-set> directory you created in the step above.
In my experience, about 200 images per class are sufficient to re-train most models.
Use labelImg to label the images you collected, and store the XML annotation files in annotations/<named-data-set>.
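To see whether each class is near the ~200-image guideline, you can count how many annotated images contain each class. The sketch below assumes labelImg's default Pascal VOC XML output, where each bounding box is an `<object>` element with the class in its `<name>` child; `count_images_per_class` is a hypothetical helper, not part of this project.

```python
import xml.etree.ElementTree as ET
from collections import Counter
from pathlib import Path


def count_images_per_class(annotation_dir):
    """Count, for each class, how many XML annotation files contain at least one box of it."""
    counts = Counter()
    for xml_path in sorted(Path(annotation_dir).glob("*.xml")):
        root = ET.parse(xml_path).getroot()
        # Collect class names into a set so several boxes of one class in an image count once.
        names = {
            obj.findtext("name").strip()
            for obj in root.iter("object")
            if obj.findtext("name")
        }
        counts.update(names)
    return counts
```

Something like `{c: n for c, n in count_images_per_class("annotations/family").items() if n < 200}` then lists the classes that may need more images (`family` is an illustrative data set name).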
Classes need to be listed in the label map. In this case I am detecting the members of my family (including pets), so the label map looks like this:
item {
  id: 1
  name: 'lindo'
}
item {
  id: 2
  name: 'nikki'
}
item {
  id: 3
  name: 'eva'
}
item {
  id: 4
  name: 'nico'
}
item {
  id: 5
  name: 'polly'
}
item {
  id: 6
  name: 'rebel'
}
item {
  id: 7
  name: 'unknown'
}
Note that id 0 is reserved for the background class, so ids start at 1. Store this file in the annotations/<named-data-set> folder with the name label_map.pbtxt.
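If you have many classes, writing the label map by hand gets tedious. This sketch generates text in the same `item { id, name }` format as the example above, numbering from 1 so that id 0 stays reserved. The helpers `label_map_text` and `write_label_map` are hypothetical, and class names are assumed not to contain quote characters.

```python
from pathlib import Path


def label_map_text(class_names):
    """Render a label map in the item { id, name } format, with ids starting at 1."""
    items = []
    # enumerate from 1 because id 0 is reserved for the background class
    for class_id, name in enumerate(class_names, start=1):
        items.append(f"item {{\n  id: {class_id}\n  name: '{name}'\n}}\n")
    return "".join(items)


def write_label_map(class_names, dataset_dir):
    """Write label_map.pbtxt into the given annotations/<named-data-set> directory."""
    path = Path(dataset_dir) / "label_map.pbtxt"
    path.write_text(label_map_text(class_names))
    return path
```

For the family example, `write_label_map(["lindo", "nikki", "eva", "nico", "polly", "rebel", "unknown"], "annotations/family")` would reproduce the seven entries shown above (`family` is an illustrative data set name).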
TFRecord is a simple binary format designed for TensorFlow for storing a sequence of records. (Read more about it here.) Before you can train your custom object detector, you must convert your data into the TFRecord format.
Since you need to train as well as validate your model, the data set is split into a training set (train.record) and a validation set (val.record). The purpose of the training set is straightforward: it is the set of examples the model learns from. The validation set is a set of examples used during training to periodically assess model accuracy.
Use the program to convert the data set into train.record and val.record.
This program is preconfigured to do an 80–20 train/validation split. Execute it by running:
$ python3 ./ --dataset_name <named-data-set>
As configured above, the program stores the .record files in the tf_record/<named-data-set> directory.
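The 80–20 split itself is simple to reason about. The sketch below is not the project's conversion program (that program also serializes each example into TFRecord files, which is omitted here); it only shows a reproducible file-level split. `split_examples`, `split_annotations` and the default `seed=42` are hypothetical choices.

```python
import random
from pathlib import Path


def split_examples(items, train_fraction=0.8, seed=42):
    """Shuffle items reproducibly and split them into (train, val) lists."""
    items = sorted(items)  # sort first so the result does not depend on input order
    random.Random(seed).shuffle(items)  # a seeded RNG makes the split repeatable across runs
    cut = int(len(items) * train_fraction)
    return items[:cut], items[cut:]


def split_annotations(annotation_dir, train_fraction=0.8, seed=42):
    """Split the XML annotation file names in a directory into train and val sets."""
    names = [p.name for p in Path(annotation_dir).glob("*.xml")]
    return split_examples(names, train_fraction, seed)
```

Fixing the seed means re-running the conversion does not silently move images between the training and validation sets.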
There are many pre-trained object detection models available in the model zoo, but you need to limit your selection to those that can be converted to quantized TensorFlow Lite object detection models. (You must use quantization-aware training, so the model must be designed with fake quantization nodes.)
To train them on your custom data set, the models are restored in TensorFlow from their checkpoints (.ckpt files), which are saved records of previous model states.
For this example, download ssd_mobilenet_v2_quantized_coco from the model zoo and save its model checkpoint files (model.ckpt.meta, model.ckpt.index and the model.ckpt.data-* weight files) to the checkpoints directory.
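A TensorFlow 1 checkpoint cannot be restored unless its index file and data shards are present together, so it is worth checking the directory before training. This sketch assumes the files keep the `model.ckpt` prefix used above; `checkpoint_files_present` is a hypothetical helper.

```python
from pathlib import Path


def checkpoint_files_present(checkpoint_dir, prefix="model.ckpt"):
    """Return True if <prefix>.meta, <prefix>.index and at least one <prefix>.data-* file exist."""
    d = Path(checkpoint_dir)
    has_meta = (d / f"{prefix}.meta").is_file()
    has_index = (d / f"{prefix}.index").is_file()
    # Variable values live in sharded files such as model.ckpt.data-00000-of-00001.
    has_data = any(d.glob(f"{prefix}.data-*"))
    return has_meta and has_index and has_data
```

For example, `checkpoint_files_present("checkpoints")` should return `True` once the files from the downloaded archive are in place.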
If required (for example, if you change the number of classes from the 7 used in this example), modify the files in the config/<named-data-set> directory accordingly. If you used the scripts above as directed, little should need to change apart from the name of your data set.
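The most common edit is the class count, which Object Detection API pipeline configs store in a `num_classes` field. The sketch below rewrites that field with a regular expression instead of parsing the protobuf text properly, so treat it as a convenience under that assumption. It does not touch other fields such as label map or record paths. `set_num_classes` and `update_config_file` are hypothetical helpers.

```python
import re
from pathlib import Path


def set_num_classes(config_text, num_classes):
    """Replace every `num_classes: N` field in pipeline config text; return (new_text, count)."""
    return re.subn(r"(\bnum_classes:\s*)\d+", rf"\g<1>{num_classes}", config_text)


def update_config_file(path, num_classes):
    """Rewrite num_classes in a config file in place; fail loudly if the field is missing."""
    text = Path(path).read_text()
    new_text, count = set_num_classes(text, num_classes)
    if count == 0:
        raise ValueError(f"no num_classes field found in {path}")
    Path(path).write_text(new_text)
    return count
```

Returning the substitution count makes it easy to notice when a config contains no `num_classes` field at all.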
Follow the steps below to re-train the model, replacing the values for pipeline_config_path and num_training_steps as needed. I found 1400 training steps to be sufficient in this example.
$ \
--pipeline_config_path ./configs/pipeline_mobilenet_v2_ssd_retrain_last_few_layers.config \
--num_training_steps 1400
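A step count is easier to judge as the number of passes (epochs) over the training set, because each step consumes one batch. The batch size is set by `batch_size` in the `train_config` section of the pipeline config. The function below is a hypothetical helper, and the numbers in the usage line are illustrative, not values from this project.

```python
def steps_to_epochs(num_steps, batch_size, num_train_examples):
    """Approximate passes over the training set: each step consumes one batch."""
    if batch_size <= 0 or num_train_examples <= 0:
        raise ValueError("batch_size and num_train_examples must be positive")
    return num_steps * batch_size / num_train_examples
```

For instance, with a hypothetical batch size of 32 and 1120 training images, `steps_to_epochs(1400, 32, 1120)` returns `40.0`, meaning each training image is seen about 40 times.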
Start TensorBoard in a new terminal to monitor training:
$ tensorboard --logdir ./train
Run the following script to export the model to a frozen graph, convert it to a TensorFlow Lite model, and compile it to run on the Edge TPU. Replace the pipeline configuration path as required, and make sure the checkpoint number matches the last training step used when training the model.
NB: this assumes the Edge TPU Compiler has been installed on your system.
$ \
--pipeline_config_path ./configs/pipeline_mobilenet_v2_ssd_retrain_last_few_layers.config \
--checkpoint_num 1400
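To find the value to pass as `--checkpoint_num`, you can look for the newest checkpoint written during training. This sketch assumes checkpoints land in `./train` (the directory TensorBoard reads above) and follow the TensorFlow 1 naming pattern `model.ckpt-<step>.index`; both are assumptions, and `latest_checkpoint_step` is a hypothetical helper.

```python
import re
from pathlib import Path

# Assumed TF1 checkpoint naming: model.ckpt-<step>.index (plus .meta and .data-* shards).
CKPT_INDEX = re.compile(r"^model\.ckpt-(\d+)\.index$")


def latest_checkpoint_step(train_dir="train"):
    """Return the highest <step> among model.ckpt-<step>.index files, or None if there are none."""
    steps = []
    for path in Path(train_dir).iterdir():
        match = CKPT_INDEX.match(path.name)
        if match:
            steps.append(int(match.group(1)))
    return max(steps, default=None)
```

After the 1400-step run above, `latest_checkpoint_step("train")` would be expected to return `1400` if the naming assumption holds.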
You can now use the retrained and compiled model with the Edge TPU Python API.