TensorFlow Liteでは、どうやら、Interpreter::AddNodeWithParameters メソッドなるものがノードを追加するようです。
例えば、FlexModelTest::AddTfLiteMulOpメソッド
void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs, const std::vector<int>& outputs) { static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; reg.builtin_code = BuiltinOperator_MUL; reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { auto* i0 = &context->tensors[node->inputs->data[0]]; auto* o = &context->tensors[node->outputs->data[0]]; return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims)); }; reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { auto* i0 = &context->tensors[node->inputs->data[0]]; auto* i1 = &context->tensors[node->inputs->data[1]]; auto* o = &context->tensors[node->outputs->data[0]]; for (int i = 0; i < o->bytes / sizeof(float); ++i) { o->data.f[i] = i0->data.f[i] * i1->data.f[i]; } return kTfLiteOk; }; CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0, nullptr, ®), kTfLiteOk); }
入力(inputs)、出力(outputs)と、TfLiteRegistration (reg) を引数に、AddNodeWithParametersを呼び出しています。
生成されるのノードは、BuiltinOperator_MUL Op のノードになるようです。
生成されるのノードは、BuiltinOperator_MUL Op のノードになるようです。
reg.prepare と reg.invoke にラムダ式で処理を設定していますね。
reg.prepare にて、入力のdimsに出力を合わせて、reg.invoke にて、2つの入力ベクタの乗算を出力ベクタに。
reg.prepare にて、入力のdimsに出力を合わせて、reg.invoke にて、2つの入力ベクタの乗算を出力ベクタに。
もう一つ。FlexModelTest::AddTfOpメソッド。
void FlexModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs, const std::vector<int>& outputs) { auto attr = [](const string& key, const string& value) { return " attr{ key: '" + key + "' value {" + value + "}}"; }; string type_attribute; switch (interpreter_->tensor(inputs[0])->type) { case kTfLiteInt32: type_attribute = attr("T", "type: DT_INT32"); break; case kTfLiteFloat32: type_attribute = attr("T", "type: DT_FLOAT"); break; case kTfLiteString: type_attribute = attr("T", "type: DT_STRING"); break; default: // TODO(b/113613439): Use nodedef string utilities to properly handle all // types. LOG(FATAL) << "Type not supported"; break; } if (op == kUnpack) { string attributes = type_attribute + attr("num", "i: 2") + attr("axis", "i: 0"); AddTfOp("FlexUnpack", "Unpack", attributes, inputs, outputs); } else if (op == kIdentity) { string attributes = type_attribute; AddTfOp("FlexIdentity", "Identity", attributes, inputs, outputs); } else if (op == kAdd) { string attributes = type_attribute; AddTfOp("FlexAdd", "Add", attributes, inputs, outputs); } else if (op == kMul) { string attributes = type_attribute; AddTfOp("FlexMul", "Mul", attributes, inputs, outputs); } else if (op == kNonExistent) { AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs); } else if (op == kIncompatibleNodeDef) { // "Cast" op is created without attributes - making it incompatible. AddTfOp("FlexCast", "Cast", "", inputs, outputs); } }
ここでは、TensorFlowのOpを "FlexXXX"に変えて、下記のFlexModelTest::AddTfOpメソッドにて、AddNodeWithParameters経由でノードを追加してます。
void FlexModelTest::AddTfOp(const char* tflite_name, const string& tf_name, const string& nodedef_str, const std::vector<int>& inputs, const std::vector<int>& outputs) { static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; reg.builtin_code = BuiltinOperator_CUSTOM; reg.custom_name = tflite_name; tensorflow::NodeDef nodedef; CHECK(tensorflow::protobuf::TextFormat::ParseFromString( nodedef_str + " op: '" + tf_name + "'", &nodedef)); string serialized_nodedef; CHECK(nodedef.SerializeToString(&serialized_nodedef)); flexbuffers::Builder fbb; fbb.Vector([&]() { fbb.String(nodedef.op()); fbb.String(serialized_nodedef); }); fbb.Finish(); flexbuffers_.push_back(fbb.GetBuffer()); auto& buffer = flexbuffers_.back(); CHECK_EQ(interpreter_->AddNodeWithParameters( inputs, outputs, reinterpret_cast<const char*>(buffer.data()), buffer.size(), nullptr, ®), kTfLiteOk); }
const TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op, int version) const { // Return the NULL Op for all ops whose name start with "Flex", allowing // the interpreter to delegate their execution. if (IsFlexOp(op)) { static TfLiteRegistration null_op{ nullptr, nullptr, &UnsupportedTensorFlowOp, nullptr, nullptr, BuiltinOperator_CUSTOM, "Flex", 1}; return &null_op; } return MutableOpResolver::FindOp(op, version); }