TensorFlow Liteでは、どうやら、Interpreter::AddNodeWithParameters メソッドなるものがノードを追加するようです。
例えば、FlexModelTest::AddTfLiteMulOpメソッド
void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs, const std::vector<int>& outputs) { static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; reg.builtin_code = BuiltinOperator_MUL; reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { auto* i0 = &context->tensors[node->inputs->data[0]]; auto* o = &context->tensors[node->outputs->data[0]]; return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims)); }; reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { auto* i0 = &context->tensors[node->inputs->data[0]]; auto* i1 = &context->tensors[node->inputs->data[1]]; auto* o = &context->tensors[node->outputs->data[0]]; for (int i = 0; i < o->bytes / sizeof(float); ++i) { o->data.f[i] = i0->data.f[i] * i1->data.f[i]; } return kTfLiteOk; }; CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0, nullptr, ®), kTfLiteOk); }
入力(inputs)、出力(outputs)と、TfLiteRegistration (reg) を引数に、AddNodeWithParametersを呼び出しています。
生成されるのノードは、BuiltinOperator_MUL Op のノードになるようです。
生成されるのノードは、BuiltinOperator_MUL Op のノードになるようです。
reg.prepare と reg.invoke にラムダ式で処理を設定していますね。
reg.prepare にて、入力のdimsに出力を合わせて、reg.invoke にて、2つの入力ベクタの乗算を出力ベクタに。
reg.prepare にて、入力のdimsに出力を合わせて、reg.invoke にて、2つの入力ベクタの乗算を出力ベクタに。
もう一つ。FlexModelTest::AddTfOpメソッド。
void FlexModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs, const std::vector<int>& outputs) { auto attr = [](const string& key, const string& value) { return " attr{ key: '" + key + "' value {" + value + "}}"; }; string type_attribute; switch (interpreter_->tensor(inputs[0])->type) { case kTfLiteInt32: type_attribute = attr("T", "type: DT_INT32"); break; case kTfLiteFloat32: type_attribute = attr("T", "type: DT_FLOAT"); break; case kTfLiteString: type_attribute = attr("T", "type: DT_STRING"); break; default: // TODO(b/113613439): Use nodedef string utilities to properly handle all // types. LOG(FATAL) << "Type not supported"; break; } if (op == kUnpack) { string attributes = type_attribute + attr("num", "i: 2") + attr("axis", "i: 0"); AddTfOp("FlexUnpack", "Unpack", attributes, inputs, outputs); } else if (op == kIdentity) { string attributes = type_attribute; AddTfOp("FlexIdentity", "Identity", attributes, inputs, outputs); } else if (op == kAdd) { string attributes = type_attribute; AddTfOp("FlexAdd", "Add", attributes, inputs, outputs); } else if (op == kMul) { string attributes = type_attribute; AddTfOp("FlexMul", "Mul", attributes, inputs, outputs); } else if (op == kNonExistent) { AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs); } else if (op == kIncompatibleNodeDef) { // "Cast" op is created without attributes - making it incompatible. AddTfOp("FlexCast", "Cast", "", inputs, outputs); } }
ここでは、TensorFlowのOpを "FlexXXX"に変えて、下記のFlexModelTest::AddTfOpメソッドにて、AddNodeWithParameters経由でノードを追加してます。
void FlexModelTest::AddTfOp(const char* tflite_name, const string& tf_name,
const string& nodedef_str,
const std::vector<int>& inputs,
const std::vector<int>& outputs) {
static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
reg.builtin_code = BuiltinOperator_CUSTOM;
reg.custom_name = tflite_name;
tensorflow::NodeDef nodedef;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
nodedef_str + " op: '" + tf_name + "'", &nodedef));
string serialized_nodedef;
CHECK(nodedef.SerializeToString(&serialized_nodedef));
flexbuffers::Builder fbb;
fbb.Vector([&]() {
fbb.String(nodedef.op());
fbb.String(serialized_nodedef);
});
fbb.Finish();
flexbuffers_.push_back(fbb.GetBuffer());
auto& buffer = flexbuffers_.back();
CHECK_EQ(interpreter_->AddNodeWithParameters(
inputs, outputs, reinterpret_cast<const char*>(buffer.data()),
buffer.size(), nullptr, ®),
kTfLiteOk);
}
const TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op,
int version) const {
// Return the NULL Op for all ops whose name start with "Flex", allowing
// the interpreter to delegate their execution.
if (IsFlexOp(op)) {
static TfLiteRegistration null_op{
nullptr, nullptr, &UnsupportedTensorFlowOp,
nullptr, nullptr, BuiltinOperator_CUSTOM,
"Flex", 1};
return &null_op;
}
return MutableOpResolver::FindOp(op, version);
}