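Following on from yesterday's look at TensorFlow Lite's InterpreterBuilder, today covers the FlexDelegate class, starting from the AcquireFlexDelegate method that appeared at the very end of it.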
// Corresponding weak declaration found in lite/model.cc.
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>
AcquireFlexDelegate() {
  return std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
      tflite::FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) {
        delete reinterpret_cast<tflite::FlexDelegate*>(delegate);
      });
}

std::unique_ptr<FlexDelegate> FlexDelegate::Create() {
  std::unique_ptr<flex::DelegateData> delegate_data;
  if (!flex::DelegateData::Create(&delegate_data).ok()) {
    fprintf(stderr, "Unable to initialize TensorFlow context.\n");
    return nullptr;
  }

  return std::unique_ptr<FlexDelegate>(
      new FlexDelegate(std::move(delegate_data)));
}

FlexDelegate::FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data)
    : TfLiteDelegate(TfLiteDelegateCreate()),
      delegate_data_(std::move(delegate_data)) {
  data_ = delegate_data_.get();
  Prepare = &flex::delegate::Prepare;
  CopyFromBufferHandle = &flex::delegate::CopyFromBufferHandle;
  flags = kTfLiteDelegateFlagsAllowDynamicTensors;
}
This is where an instance of the FlexDelegate class is created.
The FlexDelegate class
// WARNING: This is an experimental interface that is subject to change.
// Delegate that can be used to extract parts of a graph that are designed to be
// executed by TensorFlow's runtime via Eager.
//
// The interpreter must be constructed after the FlexDelegate and destructed
// before the FlexDelegate. This delegate may be used with multiple
// interpreters, but it is *not* thread-safe.
//
// Usage:
//   auto delegate = FlexDelegate::Create();
//   ... build interpreter ...
//
//   if (delegate) {
//     interpreter->ModifyGraphWithDelegate(
//         delegate.get(), /*allow_dynamic_tensors=*/true);
//   }
So is interpreter->ModifyGraphWithDelegate where the graph gets transformed?
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
  return primary_subgraph().ModifyGraphWithDelegate(delegate);
}
and primary_subgraph() is
Subgraph& primary_subgraph() {
  return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
}
so it is actually calling the ModifyGraphWithDelegate method of Subgraph.
TfLiteStatus Subgraph::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
  if (!(delegate->flags & kTfLiteDelegateFlagsAllowDynamicTensors)) {
    int last_execution_plan_index_prepared;
    TF_LITE_ENSURE_OK(&context_, PrepareOpsStartingAt(
                                     0, &last_execution_plan_index_prepared));
    if (has_dynamic_tensors_) {
      ReportError(
          "Attempting to use a delegate that only supports static-sized "
          "tensors with a graph that has dynamic-sized tensors.");
      return kTfLiteError;
    }
  }
  // TODO(aselle): Consider if it is worth storing pointers to delegates.
  // Setup additional context interface.
  SwitchToDelegateContext();

  TfLiteStatus status = delegate->Prepare(context_, delegate);

  // Remove additional context info.
  SwitchToKernelContext();
Here it looks like the context is switched over to the delegate context, and then Prepare is run.
  TF_LITE_ENSURE_OK(context_, status);

  if (!(delegate->flags & kTfLiteDelegateFlagsAllowDynamicTensors)) {
    // Reset the state to force tensor/op reallocation.
    state_ = kStateUninvokable;
    TF_LITE_ENSURE_OK(context_, AllocateTensors());
    TF_LITE_ENSURE_EQ(context_, state_, kStateInvokable);
    // After using a delegate which doesn't support dynamic tensors, make the
    // entire graph immutable.
    state_ = kStateInvokableAndImmutable;
  }

  return status;
}
The Prepare called here is flex::delegate::Prepare, which the FlexDelegate constructor assigned to the Prepare member.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  // Get the nodes in the current execution plan. Interpreter owns this array.
  TfLiteIntArray* plan;
  TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
So it seems to obtain the plan via something called GetExecutionPlan?
  // Add all custom ops starting with "Flex" to list of supported nodes.
  std::vector<int> supported_nodes;
  for (int node_index : TfLiteIntArrayView(plan)) {
    TfLiteNode* node;
    TfLiteRegistration* registration;
    TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
        context, node_index, &node, &registration));

    if (IsFlexOp(registration->custom_name)) {
      supported_nodes.push_back(node_index);
    }
  }
Ops whose custom name starts with "Flex" are added to supported_nodes.
  // Request TFLite to partition the graph and make kernels for each independent
  // node sub set.
  TfLiteIntArray* size_and_nodes =
      ConvertVectorToTfLiteIntArray(supported_nodes);
  context->ReplaceNodeSubsetsWithDelegateKernels(context, GetKernel(),
                                                 size_and_nodes, delegate);
  TfLiteIntArrayFree(size_and_nodes);
  return kTfLiteOk;
}
Hmm...
I couldn't find where context->ReplaceNodeSubsetsWithDelegateKernels( and GetKernel() come from. After digging through the source code, here it is!
It was in the Subgraph::SwitchToDelegateContext method.
void Subgraph::SwitchToDelegateContext() {
  context_->GetNodeAndRegistration = GetNodeAndRegistration;
  context_->ReplaceNodeSubsetsWithDelegateKernels =
      ReplaceNodeSubsetsWithDelegateKernels;
  context_->GetExecutionPlan = GetExecutionPlan;
}
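Going back to ModifyGraphWithDelegate: SwitchToDelegateContext assigns Subgraph's own ReplaceNodeSubsetsWithDelegateKernels to context_->ReplaceNodeSubsetsWithDelegateKernels, and only then is delegate->Prepare executed. GetExecutionPlan and GetNodeAndRegistration, which Prepare also uses, are wired up the same way.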
  // TODO(aselle): Consider if it is worth storing pointers to delegates.
  // Setup additional context interface.
  SwitchToDelegateContext();

  TfLiteStatus status = delegate->Prepare(context_, delegate);

  // Remove additional context info.
  SwitchToKernelContext();
The static Subgraph::ReplaceNodeSubsetsWithDelegateKernels function looks like this:
TfLiteStatus Subgraph::ReplaceNodeSubsetsWithDelegateKernels(
    TfLiteContext* context, TfLiteRegistration registration,
    const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate) {
  return static_cast<Subgraph*>(context->impl_)
      ->ReplaceNodeSubsetsWithDelegateKernels(registration, nodes_to_replace,
                                              delegate);
}
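It casts context->impl_ back to a Subgraph* and calls the member function of the same name. context_->impl_ is set to the Subgraph itself in the Subgraph constructor: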
Subgraph::Subgraph(ErrorReporter* error_reporter,
                   TfLiteExternalContext** external_contexts,
                   std::vector<std::unique_ptr<Subgraph>>* subgraphs)
    : context_(&owned_context_),
      error_reporter_(error_reporter),
      next_execution_plan_index_to_prepare_(0),
      external_contexts_(external_contexts),
      subgraphs_(subgraphs) {
  context_->impl_ = static_cast<void*>(this);
  context_->ResizeTensor = ResizeTensor;
  context_->ReportError = ReportErrorC;
  context_->AddTensors = AddTensors;
  context_->tensors = nullptr;
  context_->tensors_size = 0;
  context_->allow_fp32_relax_to_fp16 = false;
  context_->recommended_num_threads = -1;
  context_->GetExternalContext = GetExternalContext;
  context_->SetExternalContext = SetExternalContext;

  // Reserve some space for the tensors to avoid excessive resizing.
  tensors_.reserve(kTensorsReservedCapacity);
  nodes_and_registration().reserve(kTensorsReservedCapacity);
  // Invalid to call these these except from TfLiteDelegate
  SwitchToKernelContext();
}
This has gotten long, so I'll continue tomorrow.