Tensorflowのリポジトリから、TensorFlow Lite for Microcontrollers のサンプルコード「Magic Wand」からM5Stack Core2用のコードを生成します。M5Stack Core2内蔵のジャイロ加速度計「MPU6886」を使用します。開発環境は「PlatformIO」を使用します。
ここではM5Stack Core2用のコードの生成までで、ジェスチャは「WING(wを描く)」「RING(丸を描く)」「SLOPE(「L」を描くような感じで右斜め上から左下へ移動、そのあと右へ水平移動)」による動作確認は行っていません。
サンプルコード「Magic Wand」の取得
本家は「tflite-micro/tensorflow/lite/micro/examples/magic_wand/」になっており、次のように,makeコマンドによりesp-idf用のソースコード一式を生成し、M5Stackの開発環境「PlatformIO」へ持ってきやすいようにzipでまとめます。
$ make -f tensorflow/lite/micro/tools/make/Makefile TARGET=esp generate_magic_wand_esp_project $ cd tensorflow/lite/micro/tools/make/gen/esp_xtensa-esp32/prj/magic_wand/esp-idf/ $ ls -CF CMakeLists.txt LICENSE README_ESP.md components/ main/ $ zip -r esp32mw.zip components main
「components/tfmicro」フォルダを「PlatformIO」の「lib」配下へ、「main」フォルダの中身を「PlatformIO」の「src」配下へ移動します。
このままだとsetup()およびloop()がかぶるので、PlatformIOが自動生成した「main.cpp」、TensorFlow Lite for Microcontrollers 側の「main.cc」「main_functions.h」は削除し、「」main_functions.ccを「」main.cppへリネームします。また、「main.cpp」の中の「#include “main_functions.h”」は削除します。
今回使用するTensorFlow Lite for Microcontrollersのサンプルコード「Magic Wand」は、本家ではないのですが、「boochow/TFLite_Micro_MicroSpeech_M5Stack」から取得します。
プロジェクト「MagicWand」のファイル構成は以下のようになります。
サンプルコード「Magic Wand」の変更
取得したサンプルコード「Magic Wand」を次のように変更します。
- 7行目でコンパイルオプションとインクルードフォルダを追加します。
platformio.ini
[env:m5stack-core2] platform = espressif32 board = m5stack-core2 framework = arduino lib_deps = m5stack/M5Core2@^0.1.3 monitor_speed = 115200 build_flags = -DARDUINOSTL_M_H -Ilib/tfmicro/third_party/gemmlowp -Ilib/tfmicro/third_party/flatbuffers/include
「accelerometer_handler.cc」ではM5Stack Core2の内蔵のジャイロ加速度計「MPU6886」を使用するために、次のように変更します。
- 17-18行目で加速度センサ「MPU6886」を初期化します。
- 31行目で加速度情報を取得します。
- 33-36行目では次の調整を行っています。M5Stackのディスプレイを手前に向けて立てたときx, y, zが(0, 0, 1)、左へ90度傾けたとき(0, 1, 0)、ディスプレイを上に向けて机に置いた状態のとき(1, 0, 0)となるようにする必要があります。M5Stackの場合は、加速度センサのデータを(z, -x, -y)の順に並べることでこの条件を満たすことができます。
accelerometer_handler.cc
#include "accelerometer_handler.h"
#include "constants.h"
#include <M5Core2.h>
float accX = 0.0F; // Define variables for storing inertial sensor data
float accY = 0.0F;
float accZ = 0.0F;
float save_data[600] = {0.0};
int begin_index = 0;
bool pending_initial_data = true;
long last_sample_millis = 0;
TfLiteStatus SetupAccelerometer(tflite::ErrorReporter *error_reporter)
{
M5.IMU.Init(); // Init IMU sensor.
M5.IMU.SetAccelFsr(M5.IMU.AFS_2G);
error_reporter->Report("Magic starts!");
return kTfLiteOk;
}
static bool UpdateData()
{
bool new_data = false;
if ((millis() - last_sample_millis) < 40)
{
return false;
}
last_sample_millis = millis();
M5.IMU.getAccelData(&accX, &accY, &accZ); // Stores the triaxial accelerometer.
save_data[begin_index++] = 1000 * accZ;
save_data[begin_index++] = -1000 * accX;
save_data[begin_index++] = -1000 * accY;
if (begin_index >= 600)
{
begin_index = 0;
}
new_data = true;
return new_data;
}
bool ReadAccelerometer(tflite::ErrorReporter *error_reporter, float *input,
int length, bool reset_buffer)
{
if (reset_buffer)
{
memset(save_data, 0, 600 * sizeof(float));
begin_index = 0;
pending_initial_data = true;
}
if (!UpdateData())
{
return false;
}
if (pending_initial_data && begin_index >= 200)
{
pending_initial_data = false;
M5.Lcd.fillScreen(BLACK);
}
if (pending_initial_data)
{
return false;
}
for (int i = 0; i < length; ++i)
{
int ring_array_index = begin_index + i - length;
if (ring_array_index < 0)
{
ring_array_index += 600;
}
input[i] = save_data[ring_array_index];
}
return true;
}
- Tensorflow関係のインクルードファイルの後に、インクルードファイル「#include <M5Core2.h>」を追加します。位置を誤るとコンパイルエラーを引き起こします。
main.cpp
#include "accelerometer_handler.h"
#include "gesture_predictor.h"
#include "magic_wand_model_data.h"
#include "output_handler.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"
#include <M5Core2.h>
// Globals, used for compatibility with Arduino-style sketches.
namespace
{
tflite::ErrorReporter *error_reporter = nullptr;
const tflite::Model *model = nullptr;
tflite::MicroInterpreter *interpreter = nullptr;
TfLiteTensor *model_input = nullptr;
int input_length;
// Create an area of memory to use for input, output, and intermediate arrays.
// The size of this will depend on the model you're using, and may need to be
// determined by experimentation.
constexpr int kTensorArenaSize = 60 * 1024;
uint8_t tensor_arena[kTensorArenaSize];
// Whether we should clear the buffer next time we fetch data
bool should_clear_buffer = false;
} // namespace
char s[128];
// The name of this function is important for Arduino compatibility.
void setup()
{
M5.begin();
Serial.begin(115200);
// Set up logging. Google style is to avoid globals or statics because of
// lifetime uncertainty, but since this has a trivial destructor it's okay.
static tflite::MicroErrorReporter micro_error_reporter; // NOLINT
error_reporter = µ_error_reporter;
// Map the model into a usable data structure. This doesn't involve any
// copying or parsing, it's a very lightweight operation.
model = tflite::GetModel(g_magic_wand_model_data);
if (model->version() != TFLITE_SCHEMA_VERSION)
{
error_reporter->Report(
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return;
}
// Pull in only the operation implementations we need.
// This relies on a complete list of all the ops needed by this graph.
// An easier approach is to just use the AllOpsResolver, but this will
// incur some penalty in code space for op implementations that are not
// needed by this graph.
static tflite::MicroMutableOpResolver micro_mutable_op_resolver; // NOLINT
micro_mutable_op_resolver.AddBuiltin(
tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
tflite::ops::micro::Register_DEPTHWISE_CONV_2D());
micro_mutable_op_resolver.AddBuiltin(
tflite::BuiltinOperator_MAX_POOL_2D,
tflite::ops::micro::Register_MAX_POOL_2D());
micro_mutable_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D,
tflite::ops::micro::Register_CONV_2D());
micro_mutable_op_resolver.AddBuiltin(
tflite::BuiltinOperator_FULLY_CONNECTED,
tflite::ops::micro::Register_FULLY_CONNECTED());
micro_mutable_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
tflite::ops::micro::Register_SOFTMAX());
// Build an interpreter to run the model with
static tflite::MicroInterpreter static_interpreter(
model, micro_mutable_op_resolver, tensor_arena, kTensorArenaSize,
error_reporter);
interpreter = &static_interpreter;
// Allocate memory from the tensor_arena for the model's tensors
interpreter->AllocateTensors();
// Obtain pointer to the model's input tensor
model_input = interpreter->input(0);
if ((model_input->dims->size != 4) || (model_input->dims->data[0] != 1) ||
(model_input->dims->data[1] != 128) ||
(model_input->dims->data[2] != kChannelNumber) ||
(model_input->type != kTfLiteFloat32))
{
error_reporter->Report("Bad input tensor parameters in model");
return;
}
input_length = model_input->bytes / sizeof(float);
// TfLiteStatus setup_status = SetupAccelerometer(error_reporter, M5.IMU);
TfLiteStatus setup_status = SetupAccelerometer(error_reporter);
if (setup_status != kTfLiteOk)
{
error_reporter->Report("Set up failed\n");
}
}
void loop()
{
bool got_data = ReadAccelerometer(error_reporter, model_input->data.f,
input_length, should_clear_buffer);
// Don't try to clear the buffer again
should_clear_buffer = false;
// If there was no new data, wait until next time
if (!got_data)
return;
// Run inference, and report any error
TfLiteStatus invoke_status = interpreter->Invoke();
if (invoke_status != kTfLiteOk)
{
error_reporter->Report("Invoke failed on index: %d\n", begin_index);
return;
}
float *f = model_input->data.f;
float *p = interpreter->output(0)->data.f;
sprintf(s, "%+10.5f : %+10.5f : %+10.5f || W %10.5f : R %10.5f : S %10.5f",
f[381], f[382], f[383], p[0], p[1], p[2]);
error_reporter->Report(s);
// Analyze the results to obtain a prediction
int gesture_index = PredictGesture(interpreter->output(0)->data.f);
// Clear the buffer next time we read data
should_clear_buffer = gesture_index < 3;
// Produce an output
HandleOutput(error_reporter, gesture_index);
}
時間軸については、加速度計測のサンプリングレートが「constants.h」の中で次のように「25Hz」と定義されています。
const float kTargetHz = 25;
サンプルコード「Magic Wand」の実行
サンプルコード「Magic Wand」をビルドしてM5Stack Core2にアップロードし、シリアルモニタで確認すると、次のようなメッセージが表示されます。
加速度センサーに対応して学習したデータを実装すればエッジAIとして動作すると思われます。

