TensorFlow Lite の delegate に、eager が追加されたのは、r1.11 の時、TensorFlowのEagerモードのコードを動かすための仕組みだと思っていた。
先週からのブログでアップしていたのが、その部分のコード。
ここにきて、TensorFlowのOpをTensorFlow Lite内で動かす仕組みだということが分かった。
ここにきて、TensorFlowのOpをTensorFlow Lite内で動かす仕組みだということが分かった。
最初のコードにもちゃんと書いてあった。
// Note: this is part of TF Lite's Eager delegation code which is to be // completed soon. // This is the TF Lite op that is created by the eager delegate to handle // execution of a supported subgraph. The usual flow is that the delegate // informs the interpreter of supported nodes in a graph, and each supported // subgraph is replaced with one instance of this kernel. // // The kernel is initialized with TfLiteDelegateParams from which we retrieve // the global EagerContext and BufferMap, as well as a list of inputs and // outputs to the subgraph. Those are used to build the OpData, with a list of // TensorFlow Ops that should be executed in order (which we call an OpNode). // // For each node included in the subgraph, we query the interpreter and // retrieve the associated NodeDef, which is then used to configure the // corresponding TensorFlow/Eager Op.
そして、Opを実行するコード(r1.11)
tensorflow::Status ExecuteEagerOp(tensorflow::EagerContext* eager_context, BufferMap* buffer_map, const string& op_name, const tensorflow::NodeDef& nodedef, const std::vector<int>& inputs, const std::vector<int>& outputs) { const tensorflow::AttrTypeMap* attr_types; TF_RETURN_IF_ERROR( tensorflow::AttrTypeMapForOp(op_name.c_str(), &attr_types));
第一引数が tensorflow::EagerContext* ということがポイント。このメソッドは、TensorFlow の Eagerモードで動く。
tensorflow::EagerOperation op(eager_context, op_name.c_str(), attr_types);
ほい、来たよ。op が tensorflow::EagerOperation だよ。
for (const auto& attr : nodedef.attr()) { op.MutableAttrs()->Set(attr.first, attr.second); }
ノードの属性を op に追加。
for (int input_index : inputs) { if (!buffer_map->HasTensor(input_index)) { return tensorflow::errors::Internal("Invalid tensor index ", input_index); } auto* handle = new tensorflow::TensorHandle( buffer_map->GetTensor(input_index), nullptr, nullptr, nullptr); op.AddInput(handle); handle->Unref(); } int num_retvals = outputs.size(); tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> retvals( num_retvals, nullptr);
入力および出力を設定。
TF_RETURN_IF_ERROR(EagerExecute(&op, &retvals, &num_retvals));
op を実行。
if (outputs.size() != num_retvals) { return tensorflow::errors::Internal( "Unexpected number of outputs from EagerExecute"); } for (int i = 0; i < num_retvals; ++i) { const tensorflow::Tensor* tensor = nullptr; TF_RETURN_IF_ERROR(retvals[i]->Tensor(&tensor)); buffer_map->SetFromTensorFlow(outputs[i], *tensor); retvals[i]->Unref(); } return tensorflow::Status::OK(); }
TensorFlowで実行した出力値を獲得して、出力に代入。
Status EagerExecute(EagerOperation* op, gtl::InlinedVector<TensorHandle*, 2>* retvals, int* num_retvals) { TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op)); bool op_is_local = IsLocal(op->EagerContext(), op->Device()); if (op_is_local) { return EagerLocalExecute(op, retvals, num_retvals); } if (op->EagerContext()->LogDevicePlacement()) { LOG(INFO) << "Executing op " << op->Name() << " in device " << op->Device()->name(); } return EagerRemoteExecute(op, retvals->data(), num_retvals); }
これ、TensorFlow のメソッドだよ。。。