Tensorflowのリポジトリから、TensorFlow Lite for Microcontrollers のサンプルコード「Magic Wand」からM5Stack Core2用のコードを生成します。M5Stack Core2内蔵のジャイロ加速度計「MPU6886」を使用します。開発環境は「PlatformIO」を使用します。
ここではM5Stack Core2用のコードの生成までで、ジェスチャは「WING(wを描く)」「RING(丸を描く)」「SLOPE(「L」を描くような感じで右斜め上から左下へ移動、そのあと右へ水平移動)」による動作確認は行っていません。
サンプルコード「Magic Wand」の取得
本家は「tflite-micro/tensorflow/lite/micro/examples/magic_wand/」になっており、次のように,makeコマンドによりesp-idf用のソースコード一式を生成し、M5Stackの開発環境「PlatformIO」へ持ってきやすいようにzipでまとめます。
$ make -f tensorflow/lite/micro/tools/make/Makefile TARGET=esp generate_magic_wand_esp_project $ cd tensorflow/lite/micro/tools/make/gen/esp_xtensa-esp32/prj/magic_wand/esp-idf/ $ ls -CF CMakeLists.txt LICENSE README_ESP.md components/ main/ $ zip -r esp32mw.zip components main
「components/tfmicro」フォルダを「PlatformIO」の「lib」配下へ、「main」フォルダの中身を「PlatformIO」の「src」配下へ移動します。
このままだとsetup()およびloop()がかぶるので、PlatformIOが自動生成した「main.cpp」、TensorFlow Lite for Microcontrollers 側の「main.cc」「main_functions.h」は削除し、「」main_functions.ccを「」main.cppへリネームします。また、「main.cpp」の中の「#include “main_functions.h”」は削除します。
今回使用するTensorFlow Lite for Microcontrollersのサンプルコード「Magic Wand」は、本家ではないのですが、「boochow/TFLite_Micro_MicroSpeech_M5Stack」から取得します。
プロジェクト「MagicWand」のファイル構成は以下のようになります。
サンプルコード「Magic Wand」の変更
取得したサンプルコード「Magic Wand」を次のように変更します。
- 7行目でコンパイルオプションとインクルードフォルダを追加します。
platformio.ini
[env:m5stack-core2] platform = espressif32 board = m5stack-core2 framework = arduino lib_deps = m5stack/M5Core2@^0.1.3 monitor_speed = 115200 build_flags = -DARDUINOSTL_M_H -Ilib/tfmicro/third_party/gemmlowp -Ilib/tfmicro/third_party/flatbuffers/include
「accelerometer_handler.cc」ではM5Stack Core2の内蔵のジャイロ加速度計「MPU6886」を使用するために、次のように変更します。
- 17-18行目で加速度センサ「MPU6886」を初期化します。
- 31行目で加速度情報を取得します。
- 33-36行目では次の調整を行っています。M5Stackのディスプレイを手前に向けて立てたときx, y, zが(0, 0, 1)、左へ90度傾けたとき(0, 1, 0)、ディスプレイを上に向けて机に置いた状態のとき(1, 0, 0)となるようにする必要があります。M5Stackの場合は、加速度センサのデータを(z, -x, -y)の順に並べることでこの条件を満たすことができます。
accelerometer_handler.cc
#include "accelerometer_handler.h" #include "constants.h" #include <M5Core2.h> float accX = 0.0F; // Define variables for storing inertial sensor data float accY = 0.0F; float accZ = 0.0F; float save_data[600] = {0.0}; int begin_index = 0; bool pending_initial_data = true; long last_sample_millis = 0; TfLiteStatus SetupAccelerometer(tflite::ErrorReporter *error_reporter) { M5.IMU.Init(); // Init IMU sensor. M5.IMU.SetAccelFsr(M5.IMU.AFS_2G); error_reporter->Report("Magic starts!"); return kTfLiteOk; } static bool UpdateData() { bool new_data = false; if ((millis() - last_sample_millis) < 40) { return false; } last_sample_millis = millis(); M5.IMU.getAccelData(&accX, &accY, &accZ); // Stores the triaxial accelerometer. save_data[begin_index++] = 1000 * accZ; save_data[begin_index++] = -1000 * accX; save_data[begin_index++] = -1000 * accY; if (begin_index >= 600) { begin_index = 0; } new_data = true; return new_data; } bool ReadAccelerometer(tflite::ErrorReporter *error_reporter, float *input, int length, bool reset_buffer) { if (reset_buffer) { memset(save_data, 0, 600 * sizeof(float)); begin_index = 0; pending_initial_data = true; } if (!UpdateData()) { return false; } if (pending_initial_data && begin_index >= 200) { pending_initial_data = false; M5.Lcd.fillScreen(BLACK); } if (pending_initial_data) { return false; } for (int i = 0; i < length; ++i) { int ring_array_index = begin_index + i - length; if (ring_array_index < 0) { ring_array_index += 600; } input[i] = save_data[ring_array_index]; } return true; }
- Tensorflow関係のインクルードファイルの後に、インクルードファイル「#include <M5Core2.h>」を追加します。位置を誤るとコンパイルエラーを引き起こします。
main.cpp
#include "accelerometer_handler.h" #include "gesture_predictor.h" #include "magic_wand_model_data.h" #include "output_handler.h" #include "tensorflow/lite/micro/kernels/micro_ops.h" #include "tensorflow/lite/micro/micro_error_reporter.h" #include "tensorflow/lite/micro/micro_interpreter.h" #include "tensorflow/lite/micro/micro_mutable_op_resolver.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" #include <M5Core2.h> // Globals, used for compatibility with Arduino-style sketches. namespace { tflite::ErrorReporter *error_reporter = nullptr; const tflite::Model *model = nullptr; tflite::MicroInterpreter *interpreter = nullptr; TfLiteTensor *model_input = nullptr; int input_length; // Create an area of memory to use for input, output, and intermediate arrays. // The size of this will depend on the model you're using, and may need to be // determined by experimentation. constexpr int kTensorArenaSize = 60 * 1024; uint8_t tensor_arena[kTensorArenaSize]; // Whether we should clear the buffer next time we fetch data bool should_clear_buffer = false; } // namespace char s[128]; // The name of this function is important for Arduino compatibility. void setup() { M5.begin(); Serial.begin(115200); // Set up logging. Google style is to avoid globals or statics because of // lifetime uncertainty, but since this has a trivial destructor it's okay. static tflite::MicroErrorReporter micro_error_reporter; // NOLINT error_reporter = µ_error_reporter; // Map the model into a usable data structure. This doesn't involve any // copying or parsing, it's a very lightweight operation. model = tflite::GetModel(g_magic_wand_model_data); if (model->version() != TFLITE_SCHEMA_VERSION) { error_reporter->Report( "Model provided is schema version %d not equal " "to supported version %d.", model->version(), TFLITE_SCHEMA_VERSION); return; } // Pull in only the operation implementations we need. // This relies on a complete list of all the ops needed by this graph. // An easier approach is to just use the AllOpsResolver, but this will // incur some penalty in code space for op implementations that are not // needed by this graph. static tflite::MicroMutableOpResolver micro_mutable_op_resolver; // NOLINT micro_mutable_op_resolver.AddBuiltin( tflite::BuiltinOperator_DEPTHWISE_CONV_2D, tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); micro_mutable_op_resolver.AddBuiltin( tflite::BuiltinOperator_MAX_POOL_2D, tflite::ops::micro::Register_MAX_POOL_2D()); micro_mutable_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D, tflite::ops::micro::Register_CONV_2D()); micro_mutable_op_resolver.AddBuiltin( tflite::BuiltinOperator_FULLY_CONNECTED, tflite::ops::micro::Register_FULLY_CONNECTED()); micro_mutable_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX, tflite::ops::micro::Register_SOFTMAX()); // Build an interpreter to run the model with static tflite::MicroInterpreter static_interpreter( model, micro_mutable_op_resolver, tensor_arena, kTensorArenaSize, error_reporter); interpreter = &static_interpreter; // Allocate memory from the tensor_arena for the model's tensors interpreter->AllocateTensors(); // Obtain pointer to the model's input tensor model_input = interpreter->input(0); if ((model_input->dims->size != 4) || (model_input->dims->data[0] != 1) || (model_input->dims->data[1] != 128) || (model_input->dims->data[2] != kChannelNumber) || (model_input->type != kTfLiteFloat32)) { error_reporter->Report("Bad input tensor parameters in model"); return; } input_length = model_input->bytes / sizeof(float); // TfLiteStatus setup_status = SetupAccelerometer(error_reporter, M5.IMU); TfLiteStatus setup_status = SetupAccelerometer(error_reporter); if (setup_status != kTfLiteOk) { error_reporter->Report("Set up failed\n"); } } void loop() { bool got_data = ReadAccelerometer(error_reporter, model_input->data.f, input_length, should_clear_buffer); // Don't try to clear the buffer again should_clear_buffer = false; // If there was no new data, wait until next time if (!got_data) return; // Run inference, and report any error TfLiteStatus invoke_status = interpreter->Invoke(); if (invoke_status != kTfLiteOk) { error_reporter->Report("Invoke failed on index: %d\n", begin_index); return; } float *f = model_input->data.f; float *p = interpreter->output(0)->data.f; sprintf(s, "%+10.5f : %+10.5f : %+10.5f || W %10.5f : R %10.5f : S %10.5f", f[381], f[382], f[383], p[0], p[1], p[2]); error_reporter->Report(s); // Analyze the results to obtain a prediction int gesture_index = PredictGesture(interpreter->output(0)->data.f); // Clear the buffer next time we read data should_clear_buffer = gesture_index < 3; // Produce an output HandleOutput(error_reporter, gesture_index); }
時間軸については、加速度計測のサンプリングレートが「constants.h」の中で次のように「25Hz」と定義されています。
const float kTargetHz = 25;
サンプルコード「Magic Wand」の実行
サンプルコード「Magic Wand」をビルドしてM5Stack Core2にアップロードし、シリアルモニタで確認すると、次のようなメッセージが表示されます。
加速度センサーに対応して学習したデータを実装すればエッジAIとして動作すると思われます。