Vengineerの妄想

人生を妄想しています。

TensorFlow Lite の Interpreter::AddNodeWithParameters


TensorFlow Liteでは、どうやら、Interpreter::AddNodeWithParameters メソッドなるものがノードを追加するようです。

例えば、FlexModelTest::AddTfLiteMulOpメソッド

void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
                                   const std::vector<int>& outputs) {
  static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
  reg.builtin_code = BuiltinOperator_MUL;
  reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
    auto* i0 = &context->tensors[node->inputs->data[0]];
    auto* o = &context->tensors[node->outputs->data[0]];
    return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims));
  };
  reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
    auto* i0 = &context->tensors[node->inputs->data[0]];
    auto* i1 = &context->tensors[node->inputs->data[1]];
    auto* o = &context->tensors[node->outputs->data[0]];
    for (int i = 0; i < o->bytes / sizeof(float); ++i) {
      o->data.f[i] = i0->data.f[i] * i1->data.f[i];
    }
    return kTfLiteOk;
  };

  CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0,
                                               nullptr, ®),
           kTfLiteOk);
}

入力(inputs)、出力(outputs)と、TfLiteRegistration (reg) を引数に、AddNodeWithParametersを呼び出しています。
生成されるのノードは、BuiltinOperator_MUL Op のノードになるようです。

reg.prepare と reg.invokeラムダ式で処理を設定していますね。
reg.prepare にて、入力のdimsに出力を合わせて、reg.invoke にて、2つの入力ベクタの乗算を出力ベクタに。

もう一つ。FlexModelTest::AddTfOpメソッド。

void FlexModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
                            const std::vector<int>& outputs) {
  auto attr = [](const string& key, const string& value) {
    return " attr{ key: '" + key + "' value {" + value + "}}";
  };

  string type_attribute;
  switch (interpreter_->tensor(inputs[0])->type) {
    case kTfLiteInt32:
      type_attribute = attr("T", "type: DT_INT32");
      break;
    case kTfLiteFloat32:
      type_attribute = attr("T", "type: DT_FLOAT");
      break;
    case kTfLiteString:
      type_attribute = attr("T", "type: DT_STRING");
      break;
    default:
      // TODO(b/113613439): Use nodedef string utilities to properly handle all
      // types.
      LOG(FATAL) << "Type not supported";
      break;
  }

  if (op == kUnpack) {
    string attributes =
        type_attribute + attr("num", "i: 2") + attr("axis", "i: 0");
    AddTfOp("FlexUnpack", "Unpack", attributes, inputs, outputs);
  } else if (op == kIdentity) {
    string attributes = type_attribute;
    AddTfOp("FlexIdentity", "Identity", attributes, inputs, outputs);
  } else if (op == kAdd) {
    string attributes = type_attribute;
    AddTfOp("FlexAdd", "Add", attributes, inputs, outputs);
  } else if (op == kMul) {
    string attributes = type_attribute;
    AddTfOp("FlexMul", "Mul", attributes, inputs, outputs);
  } else if (op == kNonExistent) {
    AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
  } else if (op == kIncompatibleNodeDef) {
    // "Cast" op is created without attributes - making it incompatible.
    AddTfOp("FlexCast", "Cast", "", inputs, outputs);
  }
}

ここでは、TensorFlowのOpを "FlexXXX"に変えて、下記のFlexModelTest::AddTfOpメソッドにて、AddNodeWithParameters経由でノードを追加してます。

void FlexModelTest::AddTfOp(const char* tflite_name, const string& tf_name,
                            const string& nodedef_str,
                            const std::vector<int>& inputs,
                            const std::vector<int>& outputs) {
  static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
  reg.builtin_code = BuiltinOperator_CUSTOM;
  reg.custom_name = tflite_name;

  tensorflow::NodeDef nodedef;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      nodedef_str + " op: '" + tf_name + "'", &nodedef));
  string serialized_nodedef;
  CHECK(nodedef.SerializeToString(&serialized_nodedef));
  flexbuffers::Builder fbb;
  fbb.Vector([&]() {
    fbb.String(nodedef.op());
    fbb.String(serialized_nodedef);
  });
  fbb.Finish();

  flexbuffers_.push_back(fbb.GetBuffer());
  auto& buffer = flexbuffers_.back();
  CHECK_EQ(interpreter_->AddNodeWithParameters(
               inputs, outputs, reinterpret_cast<const char*>(buffer.data()),
               buffer.size(), nullptr, ®),
           kTfLiteOk);
}

なお、Opの名前の最初の文字列が "Flex" の場合は、Eagerモードのため、GraphモードでのTensorFlow Liteでは、
下記のように、Op が NULL Op になってします。

const TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op,
                                                    int version) const {
  // Return the NULL Op for all ops whose name start with "Flex", allowing
  // the interpreter to delegate their execution.
  if (IsFlexOp(op)) {
    static TfLiteRegistration null_op{
        nullptr, nullptr, &UnsupportedTensorFlowOp,
        nullptr, nullptr, BuiltinOperator_CUSTOM,
        "Flex",  1};
    return &null_op;
  }
  return MutableOpResolver::FindOp(op, version);
}