今週と来週は、TensorFlow Liteで攻めちゃう? とりあえず、r1.13 ベースで。
今日は、InterpreterBuilder。。。
とりあえず、テストプログラムから。。。
// This makes sure the ErrorReporter is marshalled from FlatBufferModel to // the Interpreter. TEST(BasicFlatBufferModel, TestNullErrorReporter) { auto model = FlatBufferModel::BuildFromFile( "tensorflow/lite/testdata/empty_model.bin", nullptr); ASSERT_TRUE(model); std::unique_ptr<Interpreter> interpreter; TrivialResolver resolver; InterpreterBuilder(*model, resolver)(&interpreter); ASSERT_NE(interpreter->Invoke(), kTfLiteOk); }
InterpreterBuilder は、第一引数に FlatBufferModel、第二引数に OpResolver を取り、
Operator() メソッドにて、std::unique_ptr<Interpreter> のインスタンスへのポインタをわたしています。
Operator() メソッドにて、std::unique_ptr<Interpreter> のインスタンスへのポインタをわたしています。
InterpreterBuilder::Operator()メソッドは、
TfLiteStatus InterpreterBuilder::operator()( std::unique_ptr<Interpreter>* interpreter) { return operator()(interpreter, /*num_threads=*/-1); }
num_threads を -1 にして、InterpreterBuilder::operator() メソッドを呼び出しています。
TfLiteStatus InterpreterBuilder::operator()( std::unique_ptr<Interpreter>* interpreter, int num_threads) { if (!interpreter) { error_reporter_->Report( "Null output pointer passed to InterpreterBuilder."); return kTfLiteError; } // Safe exit by deleting partially created interpreter, to reduce verbosity // on error conditions. Use by return cleanup_on_error(); auto cleanup_and_error = [&interpreter]() { interpreter->reset(); return kTfLiteError; }; if (!model_) { error_reporter_->Report("Null pointer passed in as model."); return cleanup_and_error(); } if (model_->version() != TFLITE_SCHEMA_VERSION) { error_reporter_->Report( "Model provided is schema version %d not equal " "to supported version %d.\n", model_->version(), TFLITE_SCHEMA_VERSION); return cleanup_and_error(); } if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) { error_reporter_->Report("Registration failed.\n"); return cleanup_and_error(); }
え、ここまでで引数のチェック。。。
// Flatbuffer model schemas define a list of opcodes independent of the graph. // We first map those to registrations. This reduces string lookups for custom // ops since we only do it once per custom op rather than once per custom op // invocation in the model graph. // Construct interpreter with correct number of tensors and operators. auto* subgraphs = model_->subgraphs();
ここで、モデル内のSubgraphを抽出
auto* buffers = model_->buffers(); if (subgraphs->size() != 1) { error_reporter_->Report("Only 1 subgraph is currently supported.\n"); return cleanup_and_error(); }
現時点では、Subgraph は、ひとつのみ有効。
const tflite::SubGraph* subgraph = (*subgraphs)[0]; auto operators = subgraph->operators(); auto tensors = subgraph->tensors(); if (!operators || !tensors || !buffers) { error_reporter_->Report( "Did not get operators, tensors, or buffers in input flat buffer.\n"); return cleanup_and_error(); }
Subgraph から operators と tensors と buffers を獲得して、すべてがあるかどうかチェック?
何故か? buffers は、Subgraphのサイズを調べる前に獲得している。
何故か? buffers は、Subgraphのサイズを調べる前に獲得している。
interpreter->reset(new Interpreter(error_reporter_)); if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) { return cleanup_and_error(); }
ここでやっと、Intepreter を生成し、tensorsを追加。。。
// Set num threads (**interpreter).SetNumThreads(num_threads); // Parse inputs/outputs (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs())); (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));
生成して Interpreter に、スレッド数( -1 )、Subgraphの入力および出力を設定。
// Finally setup nodes and tensors if (ParseNodes(operators, interpreter->get()) != kTfLiteOk) return cleanup_and_error(); if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk) return cleanup_and_error();
Subgraph内のノードとバッファをパースする
std::vector<int> variables; for (int i = 0; i < (*interpreter)->tensors_size(); ++i) { auto* tensor = (*interpreter)->tensor(i); if (tensor->is_variable) { variables.push_back(i); } } (**interpreter).SetVariables(std::move(variables));
tensor から variable を抽出して、設定。
if (ApplyDelegates(interpreter->get()) != kTfLiteOk) return cleanup_and_error();
最後に、delegate を設定。
return kTfLiteOk; }
TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter) { // TODO(b/117561550): Move flex delegate application to the OpResolver. if (AcquireFlexDelegate == nullptr) { return kTfLiteOk; }
AcquireFlexDelegate が設定されていないときは、OK。
bool has_flex_op = false; for (const auto* registration : flatbuffer_op_index_to_registration_) { if ((registration->builtin_code == BuiltinOperator_CUSTOM) && IsFlexOp(registration->custom_name)) { has_flex_op = true; break; } } if (!has_flex_op) { return kTfLiteOk; }
has_flex_op が false であれば、OK。
if (auto flex_delegate = AcquireFlexDelegate()) { return interpreter->ModifyGraphWithDelegate(std::move(flex_delegate)); } return kTfLiteOk; }
AcquireFlexDelegate() の戻り値を使って、InterpreterのModifyGraphWithDelegateメソッドの戻り値。
AcquireFlexDelegate() は、
#if TFLITE_HAS_ATTRIBUTE(weak) || (defined(__GNUC__) && !defined(__clang__)) // Using weak symbols for the flex delegate allows automatic injection of the // delegate simply by adding it as a dependency. See also the strong override in // lite/delegates/flex/delegate.cc. __attribute__((weak)) Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() { return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); } #else Interpreter::TfLiteDelegatePtr (*AcquireFlexDelegate)() = nullptr; #endifのようになっていて、delegates/flex/delegate.cc で、
// Corresponding weak declaration found in lite/model.cc. std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)> AcquireFlexDelegate() { return std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>( tflite::FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) { delete reinterpret_cast<tflite::FlexDelegate*>(delegate); }); }のように、FlexDelegateクラスのインスタンスを生成している。
delegates/flex/delegate.cc は、lite/BUILD の config_setting にて、with_select_tf_ops が true の時のみ使われる。
config_setting( name = "with_select_tf_ops", define_values = {"with_select_tf_ops": "true"}, visibility = ["//visibility:public"], )
] + select({ ":with_select_tf_ops": [ "//tensorflow/lite/delegates/flex:delegate", ], "//conditions:default": [], }),