Vengineerの妄想(準備期間)

人生は短いけど、長いです。人生を楽しみましょう!

TensorFlow LiteのInterpreterBuilder


今週と来週は、TensorFlow Liteで攻めちゃう? とりあえず、r1.13 ベースで。

今日は、InterpreterBuilder。。。

とりあえず、テストプログラムから。。。


// This makes sure the ErrorReporter is marshalled from FlatBufferModel to
// the Interpreter.
TEST(BasicFlatBufferModel, TestNullErrorReporter) {
  auto model = FlatBufferModel::BuildFromFile(
      "tensorflow/lite/testdata/empty_model.bin", nullptr);
  ASSERT_TRUE(model);

  std::unique_ptr<Interpreter> interpreter;
  TrivialResolver resolver;
  InterpreterBuilder(*model, resolver)(&interpreter);
  ASSERT_NE(interpreter->Invoke(), kTfLiteOk);
}

InterpreterBuilder は、第一引数に FlatBufferModel、第二引数に OpResolver を取り、
Operator() メソッドにて、std::unique_ptr<Interpreter> のインスタンスへのポインタをわたしています。

InterpreterBuilder::Operator()メソッドは、
TfLiteStatus InterpreterBuilder::operator()(
    std::unique_ptr<Interpreter>* interpreter) {
  return operator()(interpreter, /*num_threads=*/-1);
}

num_threads を -1 にして、InterpreterBuilder::operator() メソッドを呼び出しています。

TfLiteStatus InterpreterBuilder::operator()(
    std::unique_ptr<Interpreter>* interpreter, int num_threads) {
  if (!interpreter) {
    error_reporter_->Report(
        "Null output pointer passed to InterpreterBuilder.");
    return kTfLiteError;
  }

  // Safe exit by deleting partially created interpreter, to reduce verbosity
  // on error conditions. Use by return cleanup_on_error();
  auto cleanup_and_error = [&interpreter]() {
    interpreter->reset();
    return kTfLiteError;
  };

  if (!model_) {
    error_reporter_->Report("Null pointer passed in as model.");
    return cleanup_and_error();
  }

  if (model_->version() != TFLITE_SCHEMA_VERSION) {
    error_reporter_->Report(
        "Model provided is schema version %d not equal "
        "to supported version %d.\n",
        model_->version(), TFLITE_SCHEMA_VERSION);
    return cleanup_and_error();
  }

  if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
    error_reporter_->Report("Registration failed.\n");
    return cleanup_and_error();
  }

え、ここまでで引数のチェック。。。

  // Flatbuffer model schemas define a list of opcodes independent of the graph.
  // We first map those to registrations. This reduces string lookups for custom
  // ops since we only do it once per custom op rather than once per custom op
  // invocation in the model graph.
  // Construct interpreter with correct number of tensors and operators.
  auto* subgraphs = model_->subgraphs();

ここで、モデル内のSubgraphを抽出

  auto* buffers = model_->buffers();
  if (subgraphs->size() != 1) {
    error_reporter_->Report("Only 1 subgraph is currently supported.\n");
    return cleanup_and_error();
  }

現時点では、Subgraph は、ひとつのみ有効。

  const tflite::SubGraph* subgraph = (*subgraphs)[0];
  auto operators = subgraph->operators();
  auto tensors = subgraph->tensors();
  if (!operators || !tensors || !buffers) {
    error_reporter_->Report(
        "Did not get operators, tensors, or buffers in input flat buffer.\n");
    return cleanup_and_error();
  }

Subgraph から operators と tensors と buffers を獲得して、すべてがあるかどうかチェック?
何故か? buffers は、Subgraphのサイズを調べる前に獲得している。

  interpreter->reset(new Interpreter(error_reporter_));
  if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) {
    return cleanup_and_error();
  }

ここでやっと、Intepreter を生成し、tensorsを追加。。。

  // Set num threads
  (**interpreter).SetNumThreads(num_threads);
  // Parse inputs/outputs
  (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
  (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));

生成して Interpreter に、スレッド数( -1 )、Subgraphの入力および出力を設定。

  // Finally setup nodes and tensors
  if (ParseNodes(operators, interpreter->get()) != kTfLiteOk)
    return cleanup_and_error();
  if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)
    return cleanup_and_error();

Subgraph内のノードとバッファをパースする

  std::vector<int> variables;
  for (int i = 0; i < (*interpreter)->tensors_size(); ++i) {
    auto* tensor = (*interpreter)->tensor(i);
    if (tensor->is_variable) {
      variables.push_back(i);
    }
  }
  (**interpreter).SetVariables(std::move(variables));

tensor から variable を抽出して、設定。

  if (ApplyDelegates(interpreter->get()) != kTfLiteOk)
    return cleanup_and_error();

最後に、delegate を設定。

  return kTfLiteOk;
}

Interpreter が InterpreterBuilder 経由で生成されるのは、 最後のApplyDelgatesの部分のため?

TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter) {
  // TODO(b/117561550): Move flex delegate application to the OpResolver.
  if (AcquireFlexDelegate == nullptr) {
    return kTfLiteOk;
  }

AcquireFlexDelegate が設定されていないときは、OK。

  bool has_flex_op = false;
  for (const auto* registration : flatbuffer_op_index_to_registration_) {
    if ((registration->builtin_code == BuiltinOperator_CUSTOM) &&
        IsFlexOp(registration->custom_name)) {
      has_flex_op = true;
      break;
    }
  }

  if (!has_flex_op) {
    return kTfLiteOk;
  }

has_flex_op が false であれば、OK。

  if (auto flex_delegate = AcquireFlexDelegate()) {
    return interpreter->ModifyGraphWithDelegate(std::move(flex_delegate));
  }
  return kTfLiteOk;
}

AcquireFlexDelegate() の戻り値を使って、InterpreterのModifyGraphWithDelegateメソッドの戻り値。

AcquireFlexDelegate() は、
#if TFLITE_HAS_ATTRIBUTE(weak) || (defined(__GNUC__) && !defined(__clang__))
// Using weak symbols for the flex delegate allows automatic injection of the
// delegate simply by adding it as a dependency. See also the strong override in
// lite/delegates/flex/delegate.cc.
__attribute__((weak)) Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
  return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
}
#else
Interpreter::TfLiteDelegatePtr (*AcquireFlexDelegate)() = nullptr;
#endif
のようになっていて、delegates/flex/delegate.cc で、

// Corresponding weak declaration found in lite/model.cc.
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>
AcquireFlexDelegate() {
  return std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
      tflite::FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) {
        delete reinterpret_cast<tflite::FlexDelegate*>(delegate);
      });
}
のように、FlexDelegateクラスのインスタンスを生成している。
delegates/flex/delegate.cc は、lite/BUILD の config_setting にて、with_select_tf_ops が true の時のみ使われる。
config_setting(
    name = "with_select_tf_ops",
    define_values = {"with_select_tf_ops": "true"},
    visibility = ["//visibility:public"],
)
    ] + select({
        ":with_select_tf_ops": [
            "//tensorflow/lite/delegates/flex:delegate",
        ],
        "//conditions:default": [],
    }),