Tensorflowのリポジトリから、TensorFlow Lite for Microcontrollers のサンプルコード「Magic Wand」からM5Stack Core2用のコードを生成します。M5Stack Core2内蔵のジャイロ加速度計「MPU6886」を使用します。開発環境は「PlatformIO」を使用します。

ここではM5Stack Core2用のコードの生成までで、ジェスチャは「WING(wを描く)」「RING(丸を描く)」「SLOPE(「L」を描くような感じで右斜め上から左下へ移動、そのあと右へ水平移動)」による動作確認は行っていません。

サンプルコード「Magic Wand」の取得

本家は「tflite-micro/tensorflow/lite/micro/examples/magic_wand/」になっており、次のように,makeコマンドによりesp-idf用のソースコード一式を生成し、M5Stackの開発環境「PlatformIO」へ持ってきやすいようにzipでまとめます。

$ make -f tensorflow/lite/micro/tools/make/Makefile TARGET=esp generate_magic_wand_esp_project
$ cd tensorflow/lite/micro/tools/make/gen/esp_xtensa-esp32/prj/magic_wand/esp-idf/
$ ls -CF
CMakeLists.txt  LICENSE  README_ESP.md  components/  main/
$ zip -r esp32mw.zip components main

「components/tfmicro」フォルダを「PlatformIO」の「lib」配下へ、「main」フォルダの中身を「PlatformIO」の「src」配下へ移動します。

このままだとsetup()およびloop()がかぶるので、PlatformIOが自動生成した「main.cpp」、TensorFlow Lite for Microcontrollers 側の「main.cc」「main_functions.h」は削除し、「」main_functions.ccを「」main.cppへリネームします。また、「main.cpp」の中の「#include “main_functions.h”」は削除します。

今回使用するTensorFlow Lite for Microcontrollersのサンプルコード「Magic Wand」は、本家ではないのですが、「boochow/TFLite_Micro_MicroSpeech_M5Stack」から取得します。

プロジェクト「MagicWand」のファイル構成は以下のようになります。

サンプルコード「Magic Wand」の変更

取得したサンプルコード「Magic Wand」を次のように変更します。

  • 7行目でコンパイルオプションとインクルードフォルダを追加します。

platformio.ini

[env:m5stack-core2]
platform = espressif32
board = m5stack-core2
framework = arduino
lib_deps = m5stack/M5Core2@^0.1.3
monitor_speed = 115200
build_flags = -DARDUINOSTL_M_H -Ilib/tfmicro/third_party/gemmlowp -Ilib/tfmicro/third_party/flatbuffers/include

「accelerometer_handler.cc」ではM5Stack Core2の内蔵のジャイロ加速度計「MPU6886」を使用するために、次のように変更します。

  • 17-18行目で加速度センサ「MPU6886」を初期化します。
  • 31行目で加速度情報を取得します。
  • 33-36行目では次の調整を行っています。M5Stackのディスプレイを手前に向けて立てたときx, y, zが(0, 0, 1)、左へ90度傾けたとき(0, 1, 0)、ディスプレイを上に向けて机に置いた状態のとき(1, 0, 0)となるようにする必要があります。M5Stackの場合は、加速度センサのデータを(z, -x, -y)の順に並べることでこの条件を満たすことができます。

accelerometer_handler.cc

#include "accelerometer_handler.h"

#include "constants.h"
#include <M5Core2.h>

float accX = 0.0F; // Define variables for storing inertial sensor data
float accY = 0.0F;
float accZ = 0.0F;

float save_data[600] = {0.0};
int begin_index = 0;
bool pending_initial_data = true;
long last_sample_millis = 0;

TfLiteStatus SetupAccelerometer(tflite::ErrorReporter *error_reporter)
{
  M5.IMU.Init(); // Init IMU sensor.
  M5.IMU.SetAccelFsr(M5.IMU.AFS_2G);
  error_reporter->Report("Magic starts!");
  return kTfLiteOk;
}

static bool UpdateData()
{
  bool new_data = false;
  if ((millis() - last_sample_millis) < 40)
  {
    return false;
  }
  last_sample_millis = millis();
  M5.IMU.getAccelData(&accX, &accY, &accZ); // Stores the triaxial accelerometer.

  save_data[begin_index++] = 1000 * accZ;
  save_data[begin_index++] = -1000 * accX;
  save_data[begin_index++] = -1000 * accY;

  if (begin_index >= 600)
  {
    begin_index = 0;
  }
  new_data = true;

  return new_data;
}

bool ReadAccelerometer(tflite::ErrorReporter *error_reporter, float *input,
                       int length, bool reset_buffer)
{
  if (reset_buffer)
  {
    memset(save_data, 0, 600 * sizeof(float));
    begin_index = 0;
    pending_initial_data = true;
  }

  if (!UpdateData())
  {
    return false;
  }

  if (pending_initial_data && begin_index >= 200)
  {
    pending_initial_data = false;
    M5.Lcd.fillScreen(BLACK);
  }

  if (pending_initial_data)
  {
    return false;
  }

  for (int i = 0; i < length; ++i)
  {
    int ring_array_index = begin_index + i - length;
    if (ring_array_index < 0)
    {
      ring_array_index += 600;
    }
    input[i] = save_data[ring_array_index];
  }
  return true;
}
  • Tensorflow関係のインクルードファイルの後に、インクルードファイル「#include <M5Core2.h>」を追加します。位置を誤るとコンパイルエラーを引き起こします。

main.cpp

#include "accelerometer_handler.h"
#include "gesture_predictor.h"
#include "magic_wand_model_data.h"
#include "output_handler.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"

#include <M5Core2.h>

// Globals, used for compatibility with Arduino-style sketches.
namespace
{
  tflite::ErrorReporter *error_reporter = nullptr;
  const tflite::Model *model = nullptr;
  tflite::MicroInterpreter *interpreter = nullptr;
  TfLiteTensor *model_input = nullptr;
  int input_length;

  // Create an area of memory to use for input, output, and intermediate arrays.
  // The size of this will depend on the model you're using, and may need to be
  // determined by experimentation.
  constexpr int kTensorArenaSize = 60 * 1024;
  uint8_t tensor_arena[kTensorArenaSize];

  // Whether we should clear the buffer next time we fetch data
  bool should_clear_buffer = false;
} // namespace

char s[128];

// The name of this function is important for Arduino compatibility.
void setup()
{
  M5.begin();
  Serial.begin(115200);

  // Set up logging. Google style is to avoid globals or statics because of
  // lifetime uncertainty, but since this has a trivial destructor it's okay.
  static tflite::MicroErrorReporter micro_error_reporter; // NOLINT
  error_reporter = &micro_error_reporter;

  // Map the model into a usable data structure. This doesn't involve any
  // copying or parsing, it's a very lightweight operation.
  model = tflite::GetModel(g_magic_wand_model_data);
  if (model->version() != TFLITE_SCHEMA_VERSION)
  {
    error_reporter->Report(
        "Model provided is schema version %d not equal "
        "to supported version %d.",
        model->version(), TFLITE_SCHEMA_VERSION);
    return;
  }

  // Pull in only the operation implementations we need.
  // This relies on a complete list of all the ops needed by this graph.
  // An easier approach is to just use the AllOpsResolver, but this will
  // incur some penalty in code space for op implementations that are not
  // needed by this graph.
  static tflite::MicroMutableOpResolver micro_mutable_op_resolver; // NOLINT
  micro_mutable_op_resolver.AddBuiltin(
      tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
      tflite::ops::micro::Register_DEPTHWISE_CONV_2D());
  micro_mutable_op_resolver.AddBuiltin(
      tflite::BuiltinOperator_MAX_POOL_2D,
      tflite::ops::micro::Register_MAX_POOL_2D());
  micro_mutable_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D,
                                       tflite::ops::micro::Register_CONV_2D());
  micro_mutable_op_resolver.AddBuiltin(
      tflite::BuiltinOperator_FULLY_CONNECTED,
      tflite::ops::micro::Register_FULLY_CONNECTED());
  micro_mutable_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
                                       tflite::ops::micro::Register_SOFTMAX());

  // Build an interpreter to run the model with
  static tflite::MicroInterpreter static_interpreter(
      model, micro_mutable_op_resolver, tensor_arena, kTensorArenaSize,
      error_reporter);
  interpreter = &static_interpreter;

  // Allocate memory from the tensor_arena for the model's tensors
  interpreter->AllocateTensors();

  // Obtain pointer to the model's input tensor
  model_input = interpreter->input(0);
  if ((model_input->dims->size != 4) || (model_input->dims->data[0] != 1) ||
      (model_input->dims->data[1] != 128) ||
      (model_input->dims->data[2] != kChannelNumber) ||
      (model_input->type != kTfLiteFloat32))
  {
    error_reporter->Report("Bad input tensor parameters in model");
    return;
  }
  input_length = model_input->bytes / sizeof(float);

  // TfLiteStatus setup_status = SetupAccelerometer(error_reporter, M5.IMU);
  TfLiteStatus setup_status = SetupAccelerometer(error_reporter);
  if (setup_status != kTfLiteOk)
  {
    error_reporter->Report("Set up failed\n");
  }
}

void loop()
{
  bool got_data = ReadAccelerometer(error_reporter, model_input->data.f,
                                    input_length, should_clear_buffer);
  // Don't try to clear the buffer again
  should_clear_buffer = false;
  // If there was no new data, wait until next time
  if (!got_data)
    return;
  // Run inference, and report any error
  TfLiteStatus invoke_status = interpreter->Invoke();
  if (invoke_status != kTfLiteOk)
  {
    error_reporter->Report("Invoke failed on index: %d\n", begin_index);
    return;
  }

  float *f = model_input->data.f;
  float *p = interpreter->output(0)->data.f;
  sprintf(s, "%+10.5f : %+10.5f : %+10.5f || W %10.5f : R %10.5f : S %10.5f",
          f[381], f[382], f[383], p[0], p[1], p[2]);
  error_reporter->Report(s);

  // Analyze the results to obtain a prediction
  int gesture_index = PredictGesture(interpreter->output(0)->data.f);
  // Clear the buffer next time we read data
  should_clear_buffer = gesture_index < 3;
  // Produce an output
  HandleOutput(error_reporter, gesture_index);
}

時間軸については、加速度計測のサンプリングレートが「constants.h」の中で次のように「25Hz」と定義されています。

const float kTargetHz = 25;

サンプルコード「Magic Wand」の実行

サンプルコード「Magic Wand」をビルドしてM5Stack Core2にアップロードし、シリアルモニタで確認すると、次のようなメッセージが表示されます。

加速度センサーに対応して学習したデータを実装すればエッジAIとして動作すると思われます。