先々週の土曜日の記事「TensorFlow Lite の eager (delegate) => flex (delegate)」に出てきた「ExecuteFlexOp」は、どこで呼ばれているのか?
// Kernel invoke callback for one Flex delegate partition.
//
// 1. Registers every non-constant TFLite input of the partition in the
//    BufferMap (constants were already handled in Prepare()).
// 2. Runs each TensorFlow op of the partition, in order, via ExecuteFlexOp().
// 3. Copies shape/type of every partition output from the BufferMap into the
//    corresponding TFLite tensor, records tensor_index as its buffer handle and
//    flags the TFLite-side data as stale.
//
// Returns kTfLiteError on an empty NodeDef op, on a failed ExecuteFlexOp /
// CopyShapeAndType, or when an output was never produced into the BufferMap.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // user_data is a void*, so static_cast is the appropriate (and safer) cast.
  const auto* op_data = static_cast<OpData*>(node->user_data);
  BufferMap* buffer_map = op_data->buffer_map;
  tensorflow::EagerContext* eager_context = op_data->eager_context;

  // --- Input handling ---
  // Insert a tensor in the buffer map for all inputs that are not constant.
  // Constants were handled in Prepare() already.
  for (auto tensor_index : op_data->subgraph_inputs) {
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    if (!IsConstantTensor(tensor)) {
      // If this tensor is part of an earlier TF subgraph we should not add it
      // to the BufferMap again, because TF already knows about it and its
      // contents are kept automatically up-to-date.
      if (!buffer_map->IsTensorFlowTensor(tensor_index)) {
        buffer_map->SetFromTfLite(tensor_index, tensor);
      }
    }
  }
  // Up to this point: input data handling.

  // --- Execution ---
  // The delegated node stands for a whole subgraph, so the TensorFlow ops it
  // contains are executed one after another through ExecuteFlexOp().
  for (const auto& node_data : op_data->nodes) {
    if (node_data.nodedef.op().empty()) {
      context->ReportError(context, "Invalid NodeDef in Flex op '%s'",
                           node_data.name.c_str());
      return kTfLiteError;
    }
    auto status =
        ExecuteFlexOp(eager_context, buffer_map, node_data.name,
                      node_data.nodedef, node_data.inputs, node_data.outputs);
    TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
  }

  // --- Output handling ---
  // Every partition output must now exist in the BufferMap; publish its
  // shape/type to TFLite and point the tensor's buffer handle at it.
  for (auto tensor_index : op_data->subgraph_outputs) {
    if (!buffer_map->HasTensor(tensor_index)) {
      context->ReportError(context, "Cannot write to invalid tensor index %d",
                           tensor_index);
      return kTfLiteError;
    }

    TfLiteTensor* tensor = &context->tensors[tensor_index];
    TF_LITE_ENSURE_OK(
        context,
        CopyShapeAndType(context, buffer_map->GetTensor(tensor_index), tensor));
    tensor->buffer_handle = tensor_index;
    tensor->data_is_stale = true;
  }

  return kTfLiteOk;
}
ここまでが出力データの処理。
そして、この Eval 関数は、下記で定義されている GetKernel() が返す TfLiteRegistration に、invoke コールバック(4番目の要素)として登録されている。
// Builds the TfLiteRegistration used for every Flex delegate partition:
// the kernel:: callbacks for init/free/prepare/invoke, no profiling_string
// callback, and kTfLiteBuiltinDelegate as the builtin code. Every other field
// stays value-initialized.
TfLiteRegistration GetKernel() {
  TfLiteRegistration flex_kernel{};
  flex_kernel.init = &kernel::Init;
  flex_kernel.free = &kernel::Free;
  flex_kernel.prepare = &kernel::Prepare;
  flex_kernel.invoke = &kernel::Eval;
  flex_kernel.profiling_string = nullptr;
  flex_kernel.builtin_code = kTfLiteBuiltinDelegate;
  return flex_kernel;
}
そして、この GetKernel() は、デリゲート側の Prepare(TfLiteContext*, TfLiteDelegate*) から呼ばれている(GetKernel() に登録されているカーネル側の kernel::Prepare とは別の関数)。
// Delegate-level Prepare callback.
//
// Scans the current execution plan, collects every node whose custom op is a
// Flex op (custom name starting with "Flex"), and asks TFLite to replace those
// nodes with delegate kernels built from GetKernel().
//
// Returns the status of ReplaceNodeSubsetsWithDelegateKernels(), so a failed
// partitioning is reported to the caller instead of being silently ignored.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  // Get the nodes in the current execution plan. Interpreter owns this array.
  TfLiteIntArray* plan = nullptr;
  TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));

  // Add all custom ops starting with "Flex" to list of supported nodes.
  std::vector<int> supported_nodes;
  for (int node_index : TfLiteIntArrayView(plan)) {
    TfLiteNode* node = nullptr;
    TfLiteRegistration* registration = nullptr;
    TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
        context, node_index, &node, &registration));
    if (IsFlexOp(registration->custom_name)) {
      supported_nodes.push_back(node_index);
    }
  }

  // Request TFLite to partition the graph and make kernels for each independent
  // node sub set.
  TfLiteIntArray* size_and_nodes =
      ConvertVectorToTfLiteIntArray(supported_nodes);
  const TfLiteStatus status = context->ReplaceNodeSubsetsWithDelegateKernels(
      context, GetKernel(), size_and_nodes, delegate);
  // Release the temporary array on both the success and the failure path.
  TfLiteIntArrayFree(size_and_nodes);
  return status;
}
ReplaceNodeSubsetsWithDelegateKernels の中身については、「TensorFlow Lite の FlexDelegate (その2)」に続く。
// Factory for the Flex delegate's kernel registration. The positional fields
// are: init, free, prepare, invoke (kernel::Eval), profiling_string (unused)
// and builtin_code (kTfLiteBuiltinDelegate); the rest are value-initialized.
TfLiteRegistration GetKernel() {
  return TfLiteRegistration{
      &kernel::Init,           // init
      &kernel::Free,           // free
      &kernel::Prepare,        // prepare
      &kernel::Eval,           // invoke
      nullptr,                 // profiling_string
      kTfLiteBuiltinDelegate,  // builtin_code
  };
}