Vengineerの妄想

人生を妄想しています。

TensorFlow Lite の eager (delegate) => flex (delegate)


TensorFlow Lite の delegate に、eager が追加されたのは、r1.11 の時、TensorFlowのEagerモードのコードを動かすための仕組みだと思っていた。

先週からのブログでアップしていたのが、その部分のコード。
ここにきて、TensorFlowのOpをTensorFlow Lite内で動かす仕組みだということが分かった。

最初のコードにもちゃんと書いてあった。

// Note: this is part of TF Lite's Eager delegation code which is to be
// completed soon.

// This is the TF Lite op that is created by the eager delegate to handle
// execution of a supported subgraph. The usual flow is that the delegate
// informs the interpreter of supported nodes in a graph, and each supported
// subgraph is replaced with one instance of this kernel.
//
// The kernel is initialized with TfLiteDelegateParams from which we retrieve
// the global EagerContext and BufferMap, as well as a list of inputs and
// outputs to the subgraph. Those are used to build the OpData, with a list of
// TensorFlow Ops that should be executed in order (which we call an OpNode).
//
// For each node included in the subgraph, we query the interpreter and
// retrieve the associated NodeDef, which is then used to configure the
// corresponding TensorFlow/Eager Op.

そして、Opを実行するコード(r1.11)
tensorflow::Status ExecuteEagerOp(tensorflow::EagerContext* eager_context,
                                  BufferMap* buffer_map, const string& op_name,
                                  const tensorflow::NodeDef& nodedef,
                                  const std::vector<int>& inputs,
                                  const std::vector<int>& outputs) {
  const tensorflow::AttrTypeMap* attr_types;
  TF_RETURN_IF_ERROR(
      tensorflow::AttrTypeMapForOp(op_name.c_str(), &attr_types));

第一引数が tensorflow::EagerContext* ということがポイント。このメソッドは、TensorFlow の Eagerモードで動く。

  tensorflow::EagerOperation op(eager_context, op_name.c_str(), attr_types);

ほい、来たよ。op が tensorflow::EagerOperation だよ。

  for (const auto& attr : nodedef.attr()) {
    op.MutableAttrs()->Set(attr.first, attr.second);
  }

ノードの属性を op に追加。

  for (int input_index : inputs) {
    if (!buffer_map->HasTensor(input_index)) {
      return tensorflow::errors::Internal("Invalid tensor index ", input_index);
    }
    auto* handle = new tensorflow::TensorHandle(
        buffer_map->GetTensor(input_index), nullptr, nullptr, nullptr);
    op.AddInput(handle);
    handle->Unref();
  }

  int num_retvals = outputs.size();
  tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> retvals(
      num_retvals, nullptr);

入力および出力を設定。

  TF_RETURN_IF_ERROR(EagerExecute(&op, &retvals, &num_retvals));

op を実行。

  if (outputs.size() != num_retvals) {
    return tensorflow::errors::Internal(
        "Unexpected number of outputs from EagerExecute");
  }

  for (int i = 0; i < num_retvals; ++i) {
    const tensorflow::Tensor* tensor = nullptr;
    TF_RETURN_IF_ERROR(retvals[i]->Tensor(&tensor));
    buffer_map->SetFromTensorFlow(outputs[i], *tensor);
    retvals[i]->Unref();
  }


  return tensorflow::Status::OK();
}


TensorFlowで実行した出力値を獲得して、出力に代入。


Status EagerExecute(EagerOperation* op,
                    gtl::InlinedVector<TensorHandle*, 2>* retvals,
                    int* num_retvals) {
  TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));

  bool op_is_local = IsLocal(op->EagerContext(), op->Device());

  if (op_is_local) {
    return EagerLocalExecute(op, retvals, num_retvals);
  }

  if (op->EagerContext()->LogDevicePlacement()) {
    LOG(INFO) << "Executing op " << op->Name() << " in device "
              << op->Device()->name();
  }

  return EagerRemoteExecute(op, retvals->data(), num_retvals);
}

これ、TensorFlow のメソッドだよ。。。