Vengineerの戯言

人生は短いけど、長いです。人生を楽しみましょう!

TensorFlow Liteで ExecuteFlexOp を呼んでいるところ


先々週の土曜日のTensorFlow Lite の eager (delegate) => flex (delegate)に出てきた、「ExecuteFlexOp」はどこで呼ばれているのか?


TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* op_data = reinterpret_cast<OpData*>(node->user_data);
  BufferMap* buffer_map = op_data->buffer_map;
  tensorflow::EagerContext* eager_context = op_data->eager_context;

  // Insert a tensor in the buffer map for all inputs that are not constant.
  // Constants were handled in Prepare() already.
  for (auto tensor_index : op_data->subgraph_inputs) {
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    if (!IsConstantTensor(tensor)) {
      // If this tensor is part of an earlier TF subgraph we should not add it
      // to the BufferMap again, because TF already knows about it and its
      // contents are kept automatically up-to-date.
      if (!buffer_map->IsTensorFlowTensor(tensor_index)) {
        buffer_map->SetFromTfLite(tensor_index, tensor);
      }
    }
  }
ここまでで入力データの処理。

  // Execute the TensorFlow Ops sequentially.
  for (const auto& node_data : op_data->nodes) {
    if (node_data.nodedef.op().empty()) {
      context->ReportError(context, "Invalid NodeDef in Flex op '%s'",
                           node_data.name.c_str());
      return kTfLiteError;
    }
    auto status =
        ExecuteFlexOp(eager_context, buffer_map, node_data.name,
                      node_data.nodedef, node_data.inputs, node_data.outputs);
    TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
  }

ここでは、ノードはSubgraphになっているので、Subgraph内のノードを逐次 ExecuteFlexOpメソッドにて実行。

  for (auto tensor_index : op_data->subgraph_outputs) {
    if (!buffer_map->HasTensor(tensor_index)) {
      context->ReportError(context, "Cannot write to invalid tensor index %d",
                           tensor_index);
      return kTfLiteError;
    }

    TfLiteTensor* tensor = &context->tensors[tensor_index];
    TF_LITE_ENSURE_OK(
        context,
        CopyShapeAndType(context, buffer_map->GetTensor(tensor_index), tensor));
    tensor->buffer_handle = tensor_index;
    tensor->data_is_stale = true;
  }

  return kTfLiteOk;
}

出力データの処理

そして、この Eval メソッドは、下記の部分で定義されている GetKernel() メソッドを呼ぶことで利用できる。
TfLiteRegistration GetKernel() {
  TfLiteRegistration registration{&kernel::Init,    &kernel::Free,
                                  &kernel::Prepare, &kernel::Eval,
                                  nullptr,          kTfLiteBuiltinDelegate};
  return registration;
}

そして、この GetKernel() は、Prepareにて呼ばれている。
TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  // Get the nodes in the current execution plan. Interpreter owns this array.
  TfLiteIntArray* plan;
  TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));

  // Add all custom ops starting with "Flex" to list of supported nodes.
  std::vector<int> supported_nodes;
  for (int node_index : TfLiteIntArrayView(plan)) {
    TfLiteNode* node;
    TfLiteRegistration* registration;
    TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
        context, node_index, &node, ®istration));

    if (IsFlexOp(registration->custom_name)) {
      supported_nodes.push_back(node_index);
    }
  }

  // Request TFLite to partition the graph and make kernels for each independent
  // node sub set.
  TfLiteIntArray* size_and_nodes =
      ConvertVectorToTfLiteIntArray(supported_nodes);
  context->ReplaceNodeSubsetsWithDelegateKernels(context, GetKernel(),
                                                 size_and_nodes, delegate);
  TfLiteIntArrayFree(size_and_nodes);
  return kTfLiteOk;
}


ReplaceNodeSubsetsWithDelegateKernelsは、TensorFlow Lite の FlexDelegate (その2)に繋がる。

このSubgraphを含むノードを BuiltinOperator_DELEGATEとして扱うのね。
そして、GetKernel()での、kTfLiteBuiltinDelegate になるのね。

TfLiteRegistration GetKernel() {
  TfLiteRegistration registration{&kernel::Init,    &kernel::Free,
                                  &kernel::Prepare, &kernel::Eval,
                                  nullptr,          kTfLiteBuiltinDelegate};
  return registration;
}