Vengineer's Delusions (Preparation Period)

Life is short, but it is also long. Let's enjoy it!

FlexDelegate in TensorFlow Lite (Part 1)


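Continuing from yesterday's post on TensorFlow Lite's InterpreterBuilder: this time, starting from the AcquireFlexDelegate function that appeared at the end of it, let's look at the FlexDelegate class.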
// Corresponding weak declaration found in lite/model.cc.
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>
AcquireFlexDelegate() {
  return std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
      tflite::FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) {
        delete reinterpret_cast<tflite::FlexDelegate*>(delegate);
      });
}

std::unique_ptr<FlexDelegate> FlexDelegate::Create() {
  std::unique_ptr<flex::DelegateData> delegate_data;
  if (!flex::DelegateData::Create(&delegate_data).ok()) {
    fprintf(stderr, "Unable to initialize TensorFlow context.\n");
    return nullptr;
  }

  return std::unique_ptr<FlexDelegate>(
      new FlexDelegate(std::move(delegate_data)));
}

FlexDelegate::FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data)
    : TfLiteDelegate(TfLiteDelegateCreate()),
      delegate_data_(std::move(delegate_data)) {
  data_ = delegate_data_.get();
  Prepare = &flex::delegate::Prepare;
  CopyFromBufferHandle = &flex::delegate::CopyFromBufferHandle;
  flags = kTfLiteDelegateFlagsAllowDynamicTensors;
}

This is where the FlexDelegate instance gets created.
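A side note on the deleter in AcquireFlexDelegate: FlexDelegate derives from the C struct TfLiteDelegate, which has no virtual destructor, so deleting through a TfLiteDelegate* would not run FlexDelegate's destructor (it is undefined behavior). The custom deleter therefore casts back to the derived type before deleting. Below is a minimal sketch of that pattern with hypothetical stand-in types (BaseDelegate, DerivedDelegate and AcquireDelegate are my names, not TFLite's). It uses static_cast, the checked downcast for single non-virtual inheritance like this, where the original uses reinterpret_cast.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for a C-style base struct such as TfLiteDelegate:
// no virtual destructor.
struct BaseDelegate {
  int flags = 0;
};

// Counts derived destructor runs so the deleter's effect can be observed.
static int g_derived_destroyed = 0;

// Hypothetical stand-in for a C++ class such as FlexDelegate.
struct DerivedDelegate : BaseDelegate {
  int* payload = new int(42);  // resource only the derived destructor frees
  ~DerivedDelegate() {
    delete payload;
    ++g_derived_destroyed;
  }
};

// Same shape as the return type of AcquireFlexDelegate.
using DelegatePtr = std::unique_ptr<BaseDelegate, void (*)(BaseDelegate*)>;

// Hands out a base-typed pointer, but the deleter remembers the real type.
// A non-capturing lambda converts implicitly to the function-pointer deleter.
DelegatePtr AcquireDelegate() {
  return DelegatePtr(new DerivedDelegate(), [](BaseDelegate* d) {
    delete static_cast<DerivedDelegate*>(d);
  });
}
```

When the DelegatePtr goes out of scope, the lambda runs and DerivedDelegate's destructor frees payload, which a plain delete through BaseDelegate* would not guarantee.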

The FlexDelegate class
// WARNING: This is an experimental interface that is subject to change.
// Delegate that can be used to extract parts of a graph that are designed to be
// executed by TensorFlow's runtime via Eager.

Oh, it says the extracted parts are "executed by TensorFlow's runtime via Eager"...

//
// The interpreter must be constructed after the FlexDelegate and destructed
// before the FlexDelegate. This delegate may be used with multiple
// interpreters, but it is *not* thread-safe.
//
// Usage:
//   auto delegate = FlexDelegate::Create();
//   ... build interpreter ...
//
//   if (delegate) {
//     interpreter->ModifyGraphWithDelegate(
//         delegate.get(), /*allow_dynamic_tensors=*/true);
//   }
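The lifetime rule in that comment (construct the interpreter after the delegate, destroy it before) comes for free when both live in the same scope, because C++ destroys local objects in reverse order of construction. A minimal sketch with hypothetical FakeDelegate and FakeInterpreter types that log the order of events:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Records construction/destruction events so the order can be checked.
std::vector<std::string> g_events;

// Hypothetical stand-ins for FlexDelegate and Interpreter.
struct FakeDelegate {
  FakeDelegate() { g_events.push_back("delegate+"); }
  ~FakeDelegate() { g_events.push_back("delegate-"); }
};

struct FakeInterpreter {
  explicit FakeInterpreter(FakeDelegate* d) : delegate(d) {
    g_events.push_back("interpreter+");
  }
  ~FakeInterpreter() { g_events.push_back("interpreter-"); }
  FakeDelegate* delegate;  // non-owning: must not outlive the delegate
};

void RunWithDelegate() {
  FakeDelegate delegate;                   // constructed first...
  FakeInterpreter interpreter(&delegate);  // ...interpreter second
}  // destroyed in reverse order: interpreter first, then delegate
```

Declaring the interpreter before the delegate would invert the order and leave the interpreter holding a dangling pointer during its destruction.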

So is interpreter->ModifyGraphWithDelegate where the graph gets transformed?

TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
  return primary_subgraph().ModifyGraphWithDelegate(delegate);
}

and primary_subgraph() is

  Subgraph& primary_subgraph() {
    return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
  }

so it appears to simply call ModifyGraphWithDelegate on the primary Subgraph.
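The forwarding can be sketched as below. This is a simplified model with hypothetical types, not the real Interpreter and Subgraph; like the source comment, it assumes the interpreter always creates one subgraph up front, which is what makes front() safe.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical simplified Subgraph: just counts calls.
struct Subgraph {
  int modify_calls = 0;
  int ModifyGraphWithDelegate(void* /*delegate*/) {
    ++modify_calls;
    return 0;  // stand-in for kTfLiteOk
  }
};

// Hypothetical simplified Interpreter that owns its subgraphs.
class Interpreter {
 public:
  // Always creates the primary subgraph, so subgraphs_ is never empty.
  Interpreter() { subgraphs_.push_back(std::make_unique<Subgraph>()); }

  Subgraph& primary_subgraph() { return *subgraphs_.front(); }

  // The interpreter-level call only forwards to the primary subgraph.
  int ModifyGraphWithDelegate(void* delegate) {
    return primary_subgraph().ModifyGraphWithDelegate(delegate);
  }

  std::size_t subgraphs_size() const { return subgraphs_.size(); }

 private:
  std::vector<std::unique_ptr<Subgraph>> subgraphs_;
};
```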

TfLiteStatus Subgraph::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
  if (!(delegate->flags & kTfLiteDelegateFlagsAllowDynamicTensors)) {
    int last_execution_plan_index_prepared;
    TF_LITE_ENSURE_OK(&context_, PrepareOpsStartingAt(
                                     0, &last_execution_plan_index_prepared));
    if (has_dynamic_tensors_) {
      ReportError(
          "Attempting to use a delegate that only supports static-sized "
          "tensors with a graph that has dynamic-sized tensors.");
      return kTfLiteError;
    }
  }

  // TODO(aselle): Consider if it is worth storing pointers to delegates.
  // Setup additional context interface.
  SwitchToDelegateContext();

  TfLiteStatus status = delegate->Prepare(context_, delegate);

  // Remove additional context info.
  SwitchToKernelContext();

Here it seems to switch the context over to the delegate context, run Prepare, and then switch back.

  TF_LITE_ENSURE_OK(context_, status);

  if (!(delegate->flags & kTfLiteDelegateFlagsAllowDynamicTensors)) {
    // Reset the state to force tensor/op reallocation.
    state_ = kStateUninvokable;
    TF_LITE_ENSURE_OK(context_, AllocateTensors());
    TF_LITE_ENSURE_EQ(context_, state_, kStateInvokable);
    // After using a delegate which doesn't support dynamic tensors, make the
    // entire graph immutable.
    state_ = kStateInvokableAndImmutable;
  }

  return status;
}


In the FlexDelegate constructor, Prepare was set to flex::delegate::Prepare, which looks like this:

TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  // Get the nodes in the current execution plan. Interpreter owns this array.
  TfLiteIntArray* plan;
  TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));

It seems to get the plan through something called GetExecutionPlan?

  // Add all custom ops starting with "Flex" to list of supported nodes.
  std::vector<int> supported_nodes;
  for (int node_index : TfLiteIntArrayView(plan)) {
    TfLiteNode* node;
    TfLiteRegistration* registration;
    TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
        context, node_index, &node, &registration));

    if (IsFlexOp(registration->custom_name)) {
      supported_nodes.push_back(node_index);
    }
  }

Ops whose custom name starts with "Flex" are added to supported_nodes.
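The selection step amounts to a prefix check on the custom op name. Here is a sketch assuming IsFlexOp behaves as the source comment describes (the custom name starts with "Flex"; builtin ops have no custom name). Registration, StartsWithFlex and SelectFlexNodes are hypothetical names, not TFLite's.

```cpp
#include <cassert>
#include <cstring>
#include <vector>

// Hypothetical stand-in for the part of TfLiteRegistration used here.
struct Registration {
  const char* custom_name;  // nullptr for builtin ops
};

// Assumed IsFlexOp-style check: true when the custom name starts with "Flex".
bool StartsWithFlex(const char* custom_name) {
  if (custom_name == nullptr) return false;
  return std::strncmp(custom_name, "Flex", 4) == 0;
}

// Walks the execution plan and keeps the indices of Flex nodes, in plan order.
std::vector<int> SelectFlexNodes(const std::vector<int>& plan,
                                 const std::vector<Registration>& regs) {
  std::vector<int> supported_nodes;
  for (int node_index : plan) {
    if (StartsWithFlex(regs[node_index].custom_name)) {
      supported_nodes.push_back(node_index);
    }
  }
  return supported_nodes;
}
```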

  // Request TFLite to partition the graph and make kernels for each independent
  // node sub set.
  TfLiteIntArray* size_and_nodes =
      ConvertVectorToTfLiteIntArray(supported_nodes);
  context->ReplaceNodeSubsetsWithDelegateKernels(context, GetKernel(),
                                                 size_and_nodes, delegate);
  TfLiteIntArrayFree(size_and_nodes);
  return kTfLiteOk;
}
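The name size_and_nodes hints at the layout: TfLiteIntArray is a C struct holding a size followed by the int data. Below is a simplified model of building and reading such a size-prefixed array, using a std::vector<int> whose element 0 is the count. MakeSizeAndNodes and NodesFrom are hypothetical helpers, not TFLite functions.

```cpp
#include <cassert>
#include <vector>

// Builds [count, node0, node1, ...], mimicking a size-prefixed int array.
std::vector<int> MakeSizeAndNodes(const std::vector<int>& nodes) {
  std::vector<int> out;
  out.reserve(nodes.size() + 1);
  out.push_back(static_cast<int>(nodes.size()));
  out.insert(out.end(), nodes.begin(), nodes.end());
  return out;
}

// Reads the node indices back out using the stored count.
std::vector<int> NodesFrom(const std::vector<int>& size_and_nodes) {
  const int n = size_and_nodes.at(0);
  return std::vector<int>(size_and_nodes.begin() + 1,
                          size_and_nodes.begin() + 1 + n);
}
```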

Hmm...

I couldn't find where context->ReplaceNodeSubsetsWithDelegateKernels and GetKernel() come from. After chasing through the source code, I found the first one!

It was in the Subgraph::SwitchToDelegateContext method.

void Subgraph::SwitchToDelegateContext() {
  context_->GetNodeAndRegistration = GetNodeAndRegistration;
  context_->ReplaceNodeSubsetsWithDelegateKernels =
      ReplaceNodeSubsetsWithDelegateKernels;
  context_->GetExecutionPlan = GetExecutionPlan;
}

Going back to ModifyGraphWithDelegate: SwitchToDelegateContext assigns Subgraph's own ReplaceNodeSubsetsWithDelegateKernels to context_->ReplaceNodeSubsetsWithDelegateKernels, and only then is delegate->Prepare called.

  // TODO(aselle): Consider if it is worth storing pointers to delegates.
  // Setup additional context interface.
  SwitchToDelegateContext();

  TfLiteStatus status = delegate->Prepare(context_, delegate);

  // Remove additional context info.
  SwitchToKernelContext();
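The trick is that TfLiteContext is a C struct of function pointers, so the Subgraph can change which functions it exposes at runtime. Here is a minimal sketch of the "switch, call Prepare, switch back" order with a hypothetical Context that has a single callback. The rejecting stub installed outside Prepare is my assumption, based on the comment "Invalid to call these ... except from TfLiteDelegate" in the Subgraph constructor.

```cpp
#include <cassert>

// Hypothetical, heavily simplified TfLiteContext-like struct.
struct Context;
using GetPlanFn = int (*)(Context*, int* out_plan_size);

struct Context {
  GetPlanFn GetExecutionPlan;  // swappable "method"
  int plan_size;
};

// Installed while a delegate's Prepare runs: does the real work.
int RealGetExecutionPlan(Context* ctx, int* out_plan_size) {
  *out_plan_size = ctx->plan_size;
  return 0;  // ok
}

// Installed the rest of the time: calling it is an error (assumption).
int ForbiddenGetExecutionPlan(Context*, int*) { return 1; }

void SwitchToDelegateContext(Context* ctx) {
  ctx->GetExecutionPlan = RealGetExecutionPlan;
}

void SwitchToKernelContext(Context* ctx) {
  ctx->GetExecutionPlan = ForbiddenGetExecutionPlan;
}

// Same order as ModifyGraphWithDelegate: switch, Prepare, switch back.
int RunPrepare(Context* ctx, int (*prepare)(Context*)) {
  SwitchToDelegateContext(ctx);
  int status = prepare(ctx);
  SwitchToKernelContext(ctx);
  return status;
}

// A toy "delegate Prepare" that needs the execution plan.
int ToyPrepare(Context* ctx) {
  int n = 0;
  if (ctx->GetExecutionPlan(ctx, &n) != 0) return 1;
  return n == 3 ? 0 : 1;
}
```

The same Context pointer is passed everywhere; only the function stored in it changes, which is why the delegate's Prepare can call these functions while ordinary kernels cannot.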

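The static Subgraph::ReplaceNodeSubsetsWithDelegateKernels looks like this; it casts context->impl_ back to a Subgraph* and forwards to the member function of the same name: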

TfLiteStatus Subgraph::ReplaceNodeSubsetsWithDelegateKernels(
    TfLiteContext* context, TfLiteRegistration registration,
    const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate) {
  return static_cast<Subgraph*>(context->impl_)
      ->ReplaceNodeSubsetsWithDelegateKernels(registration, nodes_to_replace,
                                              delegate);
}

For that to work, the Subgraph constructor stores a pointer to the Subgraph itself (this) in context_->impl_:
Subgraph::Subgraph(ErrorReporter* error_reporter,
                   TfLiteExternalContext** external_contexts,
                   std::vector<std::unique_ptr<Subgraph>>* subgraphs)
    : context_(&owned_context_),
      error_reporter_(error_reporter),
      next_execution_plan_index_to_prepare_(0),
      external_contexts_(external_contexts),
      subgraphs_(subgraphs) {
  context_->impl_ = static_cast<void*>(this);
  context_->ResizeTensor = ResizeTensor;
  context_->ReportError = ReportErrorC;
  context_->AddTensors = AddTensors;
  context_->tensors = nullptr;
  context_->tensors_size = 0;
  context_->allow_fp32_relax_to_fp16 = false;
  context_->recommended_num_threads = -1;
  context_->GetExternalContext = GetExternalContext;
  context_->SetExternalContext = SetExternalContext;

  // Reserve some space for the tensors to avoid excessive resizing.
  tensors_.reserve(kTensorsReservedCapacity);
  nodes_and_registration().reserve(kTensorsReservedCapacity);
  // Invalid to call these except from TfLiteDelegate
  SwitchToKernelContext();
}
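This is the usual way to call back into a C++ object from a C-style table of function pointers: store this in an opaque void* field and have a static function cast it back. A minimal sketch with hypothetical Ctx and Graph types. Copying is disabled because a copied Graph's ctx_.impl_ would keep pointing at the original object.

```cpp
#include <cassert>

// Simplified C-style context: an opaque back-pointer plus one callback.
struct Ctx {
  void* impl_;
  int (*Replace)(Ctx* context, int node_count);
};

class Graph {
 public:
  Graph() {
    ctx_.impl_ = static_cast<void*>(this);  // back-pointer to the owner
    ctx_.Replace = &Graph::ReplaceTrampoline;
  }
  Graph(const Graph&) = delete;
  Graph& operator=(const Graph&) = delete;

  Ctx* context() { return &ctx_; }
  int replaced_nodes() const { return replaced_nodes_; }

 private:
  // Static entry point usable as a plain function pointer:
  // recovers the object from impl_ and forwards to the member function.
  static int ReplaceTrampoline(Ctx* context, int node_count) {
    return static_cast<Graph*>(context->impl_)->Replace(node_count);
  }

  // Member function that does the actual work on this instance.
  int Replace(int node_count) {
    replaced_nodes_ += node_count;
    return 0;
  }

  Ctx ctx_{};
  int replaced_nodes_ = 0;
};
```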

This is getting long, so I'll continue tomorrow.