Vengineerの戯言

人生は短いけど、長いです。人生を楽しみましょう!

TensorFlow Lite の OpResolver


TensorFlow Liteでは、OpResolverクラス なるものがあります。

ドキュメント、Customizing the kernel libraryにあるように、この OpResolver クラスを使って、カーネルライブラリをカスタマイズできます。

class OpResolver {
  virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0;
  virtual TfLiteRegistration* FindOp(const char* op) const = 0;
  virtual void AddOp(tflite::BuiltinOperator op, TfLiteRegistration* registration) = 0;
  virtual void AddOp(const char* op, TfLiteRegistration* registration) = 0;
};

    tflite::ops::builtin::BuiltinOpResolver resolver;

と定義ファイルした後に、

    resolver.AddOp("MY_CUSTOM_OP", Register_MY_CUSTOM_OP());

のように、AddOpメソッドにて、カスタムOpを追加できます。
ここでは、"MY_CUSTOM_OP"という名前の Op を Register_MY_CUSTOM_OP() にて登録します。

これを利用しているのが、lite/experimental/micro/micro_mutable_op_resolver.h です。

class MicroMutableOpResolver : public OpResolver {
 public:
  const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
                                   int version) const override;
  const TfLiteRegistration* FindOp(const char* op, int version) const override;
  void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
                  int min_version = 1, int max_version = 1);
  void AddCustom(const char* name, TfLiteRegistration* registration,
                 int min_version = 1, int max_version = 1);

 private:
  TfLiteRegistration registrations_[TFLITE_REGISTRATIONS_MAX];
  int registrations_len_ = 0;

  TF_LITE_REMOVE_VIRTUAL_DELETE
};

テストコードを見てみると、
TF_LITE_MICRO_TEST(TestOperations) {
  using tflite::BuiltinOperator_CONV_2D;
  using tflite::BuiltinOperator_RELU;
  using tflite::MicroMutableOpResolver;
  using tflite::OpResolver;

  static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
                                 tflite::MockPrepare, tflite::MockInvoke};

  MicroMutableOpResolver micro_mutable_op_resolver;
  micro_mutable_op_resolver.AddBuiltin(BuiltinOperator_CONV_2D, &r, 0, 2);
  micro_mutable_op_resolver.AddCustom("mock_custom", &r, 0, 3);

ここで、AddBuiltinメソッドで BuiltinOperator_CONV_2D を AddCustomメソッドで mock_custom を追加しています。

  OpResolver* resolver = µ_mutable_op_resolver;

  const TfLiteRegistration* registration =
      resolver->FindOp(BuiltinOperator_CONV_2D, 0);

ここで、BuildinOperatot_CONV_2Dが登録されているかをチェック。

  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);

registration を nullptr で初期化

  TF_LITE_MICRO_EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));

  registration = resolver->FindOp(BuiltinOperator_CONV_2D, 10);

ここで、再び、BuildinOperatot_CONV_2Dが登録されているかをチェック。

  TF_LITE_MICRO_EXPECT_EQ(nullptr, registration);

registration を nullptr で初期化

  registration = resolver->FindOp(BuiltinOperator_RELU, 0);

BuiltinOperator_RELU が登録されているかをチェック?

  TF_LITE_MICRO_EXPECT_EQ(nullptr, registration);

registration を nullptr で初期化

  registration = resolver->FindOp("mock_custom", 0);

"mock_custom" が登録されているかをチェック?

  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);

registration を nullptr で初期化

  TF_LITE_MICRO_EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));

  registration = resolver->FindOp("mock_custom", 10);

"mock_custom" が登録されているかをチェック?

  TF_LITE_MICRO_EXPECT_EQ(nullptr, registration);

registration を nullptr で初期化

  registration = resolver->FindOp("nonexistent_custom", 0);

"nonexistent_custom" が登録されているかをチェック?

  TF_LITE_MICRO_EXPECT_EQ(nullptr, registration);
}

registration を nullptr で初期化