TensorFlow Liteでは、OpResolverクラス なるものがあります。
class OpResolver { virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0; virtual TfLiteRegistration* FindOp(const char* op) const = 0; virtual void AddOp(tflite::BuiltinOperator op, TfLiteRegistration* registration) = 0; virtual void AddOp(const char* op, TfLiteRegistration* registration) = 0; };
tflite::ops::builtin::BuiltinOpResolver resolver;
と定義ファイルした後に、
resolver.AddOp("MY_CUSTOM_OP", Register_MY_CUSTOM_OP());
のように、AddOpメソッドにて、カスタムOpを追加できます。
ここでは、"MY_CUSTOM_OP"という名前の Op を Register_MY_CUSTOM_OP() にて登録します。
ここでは、"MY_CUSTOM_OP"という名前の Op を Register_MY_CUSTOM_OP() にて登録します。
これを利用しているのが、lite/experimental/micro/micro_mutable_op_resolver.h です。
class MicroMutableOpResolver : public OpResolver { public: const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, int version) const override; const TfLiteRegistration* FindOp(const char* op, int version) const override; void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration, int min_version = 1, int max_version = 1); void AddCustom(const char* name, TfLiteRegistration* registration, int min_version = 1, int max_version = 1); private: TfLiteRegistration registrations_[TFLITE_REGISTRATIONS_MAX]; int registrations_len_ = 0; TF_LITE_REMOVE_VIRTUAL_DELETE };
テストコードを見てみると、
TF_LITE_MICRO_TEST(TestOperations) { using tflite::BuiltinOperator_CONV_2D; using tflite::BuiltinOperator_RELU; using tflite::MicroMutableOpResolver; using tflite::OpResolver; static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, tflite::MockPrepare, tflite::MockInvoke}; MicroMutableOpResolver micro_mutable_op_resolver; micro_mutable_op_resolver.AddBuiltin(BuiltinOperator_CONV_2D, &r, 0, 2); micro_mutable_op_resolver.AddCustom("mock_custom", &r, 0, 3);
ここで、AddBuiltinメソッドで BuiltinOperator_CONV_2D を AddCustomメソッドで mock_custom を追加しています。
OpResolver* resolver = µ_mutable_op_resolver; const TfLiteRegistration* registration = resolver->FindOp(BuiltinOperator_CONV_2D, 0);
ここで、BuildinOperatot_CONV_2Dが登録されているかをチェック。
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
registration を nullptr で初期化
TF_LITE_MICRO_EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr)); registration = resolver->FindOp(BuiltinOperator_CONV_2D, 10);
ここで、再び、BuildinOperatot_CONV_2Dが登録されているかをチェック。
TF_LITE_MICRO_EXPECT_EQ(nullptr, registration);
registration を nullptr で初期化
registration = resolver->FindOp(BuiltinOperator_RELU, 0);
BuiltinOperator_RELU が登録されているかをチェック?
TF_LITE_MICRO_EXPECT_EQ(nullptr, registration);
registration を nullptr で初期化
registration = resolver->FindOp("mock_custom", 0);
"mock_custom" が登録されているかをチェック?
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
registration を nullptr で初期化
TF_LITE_MICRO_EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr)); registration = resolver->FindOp("mock_custom", 10);
"mock_custom" が登録されているかをチェック?
TF_LITE_MICRO_EXPECT_EQ(nullptr, registration);
registration を nullptr で初期化
registration = resolver->FindOp("nonexistent_custom", 0);
"nonexistent_custom" が登録されているかをチェック?
TF_LITE_MICRO_EXPECT_EQ(nullptr, registration); }
registration を nullptr で初期化