Vengineerの戯言

人生は短いけど、長いです。人生を楽しみましょう!

「PyTorch + XLA」のソースコード解析 (その3)


昨日、[ 「PyTorch + XLA」のソースコード解析 (その2)]の続きで、6章と7章。。

===================================================================================

6)、ComputationClientでの実行

XlaModule::Executeメソッド内で、
以下のように、xla::ComputationClient の ExecutionComputation メソッド か ExecuteReplicated が実行されます。

 auto client = XlaGetClient();
  std::vector<std::shared_ptr<xla::ComputationClient::Data>> exec_results;
  if (inputs.size() == 1) {
    exec_results.push_back(
        client->ExecuteComputation(computation, inputs.front(), &result_shape));
  } else {
    exec_results =
        client->ExecuteReplicated(computation, inputs, &result_shape);
  }
  return DecomposeComputationResult(std::move(exec_results), result_shape,
                                    module_id);

XlaGetClient が返す client の値は、「5)、実行環境の決定」 で説明しました XlaComputationClient と XrtComputationClient の2つのケースです。それぞれのケースについて、見ていきたいと思います。


6.1)、XlaComputationClientのケース

XlaComputationClient の場合では、XLAのデバイス(CPUとTPU、GPUが利用できるかは確認できていません)を利用します。XlaComputaionClientでは、client_->Execute を実行します。

std::shared_ptr<ComputationClient::Data>
XlaComputationClient::ExecuteComputation(
    const XlaComputation& computation,
    tensorflow::gtl::ArraySlice<Data*> arguments, const Shape* output_shape) {
  metrics::TimedSection timed(ExecuteMetric());
  FlushReleasedHandles();

  string device;
  // 入力データを <GlobalData> に変換します
  std::vector<GlobalData*> arguments_data =
      GetArgumentsData(arguments, &device);
  ExecutionOptions eo;
  *eo.mutable_debug_options() = legacy_flags::GetDebugOptionsFromFlags();
  *eo.add_device_handles() = GetDeviceHandle(device);
  if (output_shape != nullptr) {
    *eo.mutable_shape_with_output_layout() = *output_shape;
  }
  // この部分で、XLA 経由でデバイス上で実行します
  StatusOr<std::unique_ptr<GlobalData>> result_or_status =
      client_->Execute(computation, arguments_data, &eo);
  xrt_util::CheckComputationStatus(result_or_status.status(), {&computation});

  ProgramShape program_shape;
  if (output_shape == nullptr) {
    program_shape = computation.GetProgramShape().ValueOrDie();
    output_shape = &program_shape.result();
  }
  // Executeの戻り値を XlaData に変換します
  return std::make_shared<XlaData>(
      std::move(result_or_status.ValueOrDie()), device, *output_shape,
      [this](XlaData* xla_data) { ReleaseXlaData(xla_data); });
}

TPUを利用する場合は、

export XLA_USE_XRT=0 export XLA_GRPC_HOST="" XLA_PLATFORM="CPU"

を ComputationClient::Create() メソッド の

    XlaComputationClient::Options options;
    options.host_name = GetEnvString("XLA_GRPC_HOST", "localhost");
    options.port = GetEnvInt("XLA_GRPC_PORT", 51000);
    options.platform = GetEnvString("XLA_PLATFORM", "TPU");
    client.reset(new XlaComputationClient(options));

にあるような環境変数を設定すればいいと思う。
つまり、環境変数 XLA_PLATFORM を “TPU” にするか、環境変数 XLA_PLATFORM を設定しない。


ExecuteReplicated メソッド は、まだ、サポートされています。

std::vector<std::shared_ptr<ComputationClient::Data>>
XlaComputationClient::ExecuteReplicated(
    const XlaComputation& computation,
    const std::vector<std::vector<Data*>>& arguments,
    const Shape* output_shape) {
  metrics::TimedSection timed(ExecuteReplicatedMetric());
  LOG(FATAL) << "ExecuteReplicated() API not yet implemented!";
}


6.2)、XrtComputationClientのケース

XrtComputationClient の場合では、CPUとTPUが利用できます。session->session.Run で実行されます。

std::shared_ptr<ComputationClient::Data>
XrtComputationClient::ExecuteComputation(
    const XlaComputation& computation,
    tensorflow::gtl::ArraySlice<Data*> arguments, const Shape* output_shape) {
  metrics::TimedSection timed(ExecuteMetric());
  ApiCallInitialize();

  std::vector<string> devices;
  tensorflow::ClientSession::FeedType feed_inputs;
  std::vector<ExecuteContext> exec_ops =
      CreateExecuteOps(computation, BuildParallelArguments(arguments),
                       output_shape, &devices, &feed_inputs);
  SessionData* session = GetSessionForDevice(devices.front());
  std::vector<tensorflow::Tensor> outputs;
  TF_CHECK_OK(session->root.status());
  xrt_util::CheckComputationStatus(
      session->session.Run(feed_inputs, {exec_ops.front().execute_output},
                           &outputs),
      {&computation});
  XLA_CHECK_EQ(outputs.size(), 1);

  return std::make_shared<XrtData>(
      devices.front(), outputs[0].scalar<int64>()(),
      exec_ops.front().result_shape,
      [this](XrtData* xrt_data) { ReleaseXrtData(xrt_data); });
}

XrtComputationClient (TPU)では、ExecuteReplicated メソッドをサポートしています。

std::vector<std::shared_ptr<ComputationClient::Data>>
XrtComputationClient::ExecuteReplicated(
    const XlaComputation& computation,
    const std::vector<std::vector<Data*>>& arguments,
    const Shape* output_shape) {
  metrics::TimedSection timed(ExecuteReplicatedMetric());
  ApiCallInitialize();

  std::vector<string> devices;
  tensorflow::ClientSession::FeedType feed_inputs;
  std::vector<ExecuteContext> exec_ops = CreateExecuteOps(
      computation, arguments, output_shape, &devices, &feed_inputs);
  return RunComputations(exec_ops, {&computation}, devices, feed_inputs);
}

CreateExecuteOpsメソッドでは、std::vector<XrtComputationClient::ExecuteContext> を生成します。
生成された std::vector<XrtComputationClient::ExecuteContext> は、RunComputations メソッドにて、デバイスの数、Replicaを作って実行するようです。つまり、実行環境にあわせて Replica を勝手に作ってくれるようです。なので、TPUがたくさんあると、それなりに速くなるのでしょうかね。。。

std::vector<XrtComputationClient::ExecuteContext>
XrtComputationClient::CreateExecuteOps(
    const XlaComputation& computation,
    const std::vector<std::vector<Data*>>& arguments, const Shape* output_shape,
    std::vector<string>* devices,
    tensorflow::ClientSession::FeedType* feed_inputs) {
  ProgramShape program_shape;
  if (output_shape == nullptr) {
    program_shape = computation.GetProgramShape().ValueOrDie();
    output_shape = &program_shape.result();
  }
  *devices = GetReplicasDevices(arguments);
  auto xrt_computation = CreateXrtComputation(computation, arguments.size(),
                                              *devices, output_shape);

  absl::optional<tensorflow::ops::Placeholder> computation_holder;
  std::vector<ExecuteContext> exec_ops;
  for (size_t i = 0; i < arguments.size(); ++i) {
    auto inputs = GetArgumentsInputs(arguments[i], devices->at(i), feed_inputs);
    const string& xrt_device = TorchDeviceToXrtDevice(devices->at(i));

    // この部分が、XRT に関する部分になるようです。
    SessionData* session = GetSessionForXrtDevice(xrt_device);
    tensorflow::Scope device_scope = session->root.WithDevice(xrt_device);
    const CachedNode& cached_node =
        GetCompileExecuteNode(device_scope, devices->at(i));
    // GetCompileExecuteNode メソッドにて、
    // XRTCompile Op から XRTExecute Op につながるようになっています。
    // ここが、XRT のポイントなんでしょうか?

const XrtComputationClient::CachedNode&
XrtComputationClient::GetCompileExecuteNode(const tensorflow::Scope& scope,
                                            const string& device) {
  NodeCache* cache =
      &node_cache_[NodeCacheKey(device, NodeTypes::kCompileExecute)];
  if (cache->empty()) {
    std::vector<tensorflow::ops::Placeholder> holders(
        {tensorflow::ops::Placeholder(scope, tensorflow::DT_STRING),
         tensorflow::ops::Placeholder(scope, tensorflow::DT_STRING),
         tensorflow::ops::Placeholder(
             scope, tensorflow::DT_INT64,
             tensorflow::ops::Placeholder::Shape({-1}))});
    auto computation_handle = tensorflow::ops::XRTCompile(scope, holders[0]);
    std::unique_ptr<CachedNode> node(new CachedNode(
        tensorflow::ops::XRTExecute(scope, computation_handle.handle,
                                    holders[1],
                                    {tensorflow::Output(holders[2])}),
        std::move(holders)));
    cache->add(std::move(node));
  }
  return cache->get();
}

    feed_inputs->insert(
        {cached_node.holders[0], xrt_computation->SerializeAsString()});

    xrt::XRTExecutionConfig exec_config;
    exec_config.set_core_index_in_replica(0);
    exec_config.set_release_input_handles(false);
    exec_config.set_release_compilation_handle(true);
    feed_inputs->insert(
        {cached_node.holders[1], exec_config.SerializeAsString()});
    feed_inputs->insert({cached_node.holders[2], inputs});

    exec_ops.emplace_back(*cached_node.output, *output_shape);
  }
  return exec_ops;
}

CreateXrtComputation メソッドは、xla::XlaComputation から xrt::XlaComputation を生成します。

std::unique_ptr<xrt::XLAComputation> XrtComputationClient::CreateXrtComputation(
    const XlaComputation& computation, int64 num_replicas,
    const std::vector<string>& devices, const Shape* output_shape) const {
  XLA_CHECK_EQ(num_replicas, devices.size());
  std::unique_ptr<xrt::XLAComputation> xrt_computation(
      new xrt::XLAComputation());
  auto config = xrt_computation->mutable_config();
  config->set_num_replicas(num_replicas);
  config->set_num_cores_per_replica(1);
  if (num_replicas > 1) {
    auto device_assignment = config->mutable_device_assignment();
    auto computation_device = device_assignment->add_computation_devices();
    for (int64 i = 0; i < num_replicas; ++i) {
      const string& xrt_device = TorchDeviceToXrtDevice(devices[i]);
      const auto& core_coords = GetDeviceMeshCoords(xrt_device);
      auto replica_device = computation_device->add_replica_devices();
      for (auto coord : core_coords) {
        replica_device->add_value(coord);
      }
    }
  }
  *config->mutable_program_shape() = computation.GetProgramShape().ValueOrDie();
  if (output_shape != nullptr) {
    *config->mutable_program_shape()->mutable_result() = *output_shape;
  }
  *xrt_computation->mutable_hlo_snapshot() =
      *computation.Snapshot().ValueOrDie();
  return xrt_computation;
}

CreateExecuteOpsメソッドの戻り値である exec_ops を RunComputationsメソッドに渡します。sess_replica.first->session.Run で実行されます。

std::vector<std::shared_ptr<ComputationClient::Data>>
XrtComputationClient::RunComputations(
    const std::vector<ExecuteContext>& exec_ops,
    tensorflow::gtl::ArraySlice<const XlaComputation* const> computations,
    const std::vector<string>& devices,
    const tensorflow::ClientSession::FeedType& feed_inputs) {
  // 「PyTorch + XLA」のメインとなる部分だと思います。
  // この部分が XLA ではなく、 XRT を使うメリットなのでしょうかね。
  // In the PyTorch/XRT interface we keep a map (options_.workers_map) from a
  // worker+taskno, to the GRPC server which is the entry point for that worker.
  // Since XRT could re-distribute ops internally, if we have N hosts
  // (worker+taskno), we could have all the workers pointing to a single GRPC
  // entry point, or we could have each worker pointing directly to the target
  // host.
  // The advantage of the latter approach, is that we do not bottleneck
  // (especially when feeding inputs) the single GRPC entry point.
  // Using the N:1 approach, the session_replicas below will contain a single
  // session, and all the replica executions will go through it (and distributed
  // by XRT on the service side).
  // Chosing the 1:1 approach (one session per worker), we will have N sessions
  // within the session_replicas map, which we will be executing independently.
  std::map<SessionData*, std::vector<size_t>> session_replicas;
  for (size_t i = 0; i < devices.size(); ++i) {
    SessionData* session = GetSessionForDevice(devices[i]);
    session_replicas[session].push_back(i);
  }
  // TODO(dlibenzi): These could be run in parallel.
 // devices.size() 分、Replicas を作って実行している
 // この部分がスケールすると、嬉しいということですね。
  std::vector<std::shared_ptr<Data>> results(devices.size());
  for (auto& sess_replica : session_replicas) {
    std::vector<tensorflow::Output> exec_nodes;
    for (auto replica : sess_replica.second) {
      exec_nodes.push_back(exec_ops[replica].execute_output);
    }
    std::vector<tensorflow::Tensor> outputs;
    TF_CHECK_OK(sess_replica.first->root.status());
    xrt_util::CheckComputationStatus(
        sess_replica.first->session.Run(feed_inputs, exec_nodes, &outputs),
        computations);
    // ここで、session.Run を実行する
    XLA_CHECK_EQ(outputs.size(), exec_nodes.size());

    for (size_t i = 0; i < outputs.size(); ++i) {
      auto replica = sess_replica.second[i];
      results[replica] = std::make_shared<XrtData>(
          devices[replica], outputs[i].scalar<int64>()(),
          exec_ops[replica].result_shape,
          [this](XrtData* xrt_data) { ReleaseXrtData(xrt_data); });
    }
  }
  return results;
}

簡単にまとめると、

 ・XlaComputation の時 (XLA_USE_XRT=0)は、session->session.Run を実行
 ・XrtComputation の時 (XLA_USE_XRT=1)は、sess_replica.first->session.Run を実行

 Clound TPU の場合は、TPUの数分、replica を作って、並列処理を行える!


7)、backward メソッド

テストコード内の test_nll_loss というテストでは、backward メソッドを呼んでいます。
また、XXXGrad(例えば、AvgPoolGrad クラス)と名前のクラスでも同じように、backward メソッドが呼ばれています。

   def test_nll_loss(self):
        input = torch.randn(3, 5, requires_grad=True)
        target = torch.empty(3, dtype=torch.long).random_(5)
        model = XlaNllLoss()
        traced_model = torch.jit.trace(model, (input, target))
        xla_model = torch_xla._C.XlaModule(traced_model)
        xla_inputs = [torch_xla._C.XLATensor(input), 
                      torch_xla._C.XLATensor(target)]
        output_xla = xla_model((tuple(xla_inputs)))
        xla_model.backward(*output_xla)
        output = model(input, target)
        output.backward()
        self.assertEqual(input.grad.data, xla_inputs[0].grad.data.to_tensor())

forward メソッドについては、これまでに説明してきましたが、backward メソッドについては、この章で説明します。backward メソッドは、forward メソッドに比べて非常に長いです。backward メソッドは、InitXlaModuleBindingsメソッドの中のXlaModuleクラスのメソッド ( backward ) として定義されています。

     .def("backward",
           [](XlaModule& xla_module, py::args args) {
             auto inputs = XlaCreateTensorList(args);
             xla_module.backward(inputs);
           })

backwardメソッドでは、C++のXlaModuleの backward メソッドを呼んでいます。

forward側のXlaComputation (forward_computation)の生成は、 グラフ (f_) をベースに行います。

    XlaTranslator xla_fwd_impl(f_, GetPrecisionConfig());
    forward_computation_ = xla_fwd_impl.BuildComputation(forward_shapes);

一方、backward側のXlaComputation (back_computation)の生成は、 グラフ (df_) をベースに行います。

    XlaTranslator xla_bwd_impl(df_, GetPrecisionConfig());
    backward_computation_ = xla_bwd_impl.BuildComputation(
        backward_shapes, GetBackwardBuildOptions(0, inputs_.size()));

forward と backward では違うグラフ ( f_ と df_ )にて、XlaComputationを生成しています。では、このグラフ ( f_ と df_ )はどのように求めたのでしょうか? グラフ ( f_ と df_ )は、XlaModule::Initialize メソッドにて、次のように設定されています。

void XlaModule::Initialize(const TensorBatchVector& inputs) {
 if (script_module_ == nullptr) {
    return;
  }

  // Get forward graph.
  const auto forward = script_module_->find_method("forward");
  // どうやら、”forward”メソッドを見つけているようだ
  JIT_ASSERT(forward);
  std::shared_ptr<Graph> forward_graph = forward->graph()->copy();
  auto forward_graph_copy = forward_graph->copy();
  // 見つけたものからグラフをゲットする
  Gradient gradient = differentiate(forward_graph_copy);
  // その微分を求める

  // 途中略

  // Take ownership of the forward and differentiated graphs and release the
  // reference to the script module to mark initialization as done.
  f_ = gradient.f;
  df_ = gradient.df;
  // 求めた微分からグラフ ( f と df ) を獲得している
  // Mark the module as initialized.
  script_module_ = nullptr;
}

forwardメソッドを見つけるためのscript_module_ は、コンストラクタにて module が設定されています。

XlaModule::XlaModule(const std::shared_ptr<script::Module> module,
                     bool use_full_conv_precision, bool differentiate)
    : use_full_conv_precision_(use_full_conv_precision),
      enable_trace_fusion_(differentiate),
      differentiate_(differentiate),
      module_id_(s_module_id_++),
      script_module_(module) {}

テストコードの torch_xla._C.XlaModule の引数である traced_model が script_module_ になります。

class TestMulAdd(XlaTestCase):
    def test(self):


       class XlaMulAdd(nn.Module):
            def forward(self, x, y):
                return x * y + y

        x = torch.rand(3, 5)
        y = torch.rand(3, 5)
        model = XlaMulAdd()
        traced_model = torch.jit.trace(model, (x, y))
        xla_model = torch_xla._C.XlaModule(traced_model)
        inputs_xla = [torch_xla._C.XLATensor(x), torch_xla._C.XLATensor(y)]
        output_xla = xla_model((tuple(inputs_xla)))
        expected = model(x, y)
        self.assertEqualDbg(output_xla[0][0].to_tensor().data, expected.data)

const auto forward = script_module_->find_method("forward";);

は、上記のテストコードの TestMulAdd クラスでは、XlaMulAdd クラスの forward メソッドになります。

  std::shared_ptr<Graph> forward_graph = forward->graph()->copy();
  auto forward_graph_copy = forward_graph->copy();
  Gradient gradient = differentiate(forward_graph_copy);

は、foward メソッドのグラフから diffrerentiate メソッドにて、Gradient を求めていることになります。

differentiate メソッドは、PyTorchの autodiff.cpp の中で次のように定義されています。

Gradient differentiate(std::shared_ptr<Graph>& graph) {
  Gradient grad_desc;
  // Take ownership of the graph
  JIT_ASSERTM(graph.use_count() == 1,
              "differentiate will mutate and destroy the graph, so it requires "
              "graph.use_count() == 1, but found %d", graph.use_count());
  std::swap(graph, grad_desc.f);
  // XXX: Take care when handling outputs - they can be duplicated!

  WithInsertPoint guard(grad_desc.f->block());
  // Fills in df_input_vjps and df_output_vjps
  auto rev_info = addReverseInline(grad_desc);
  // Lift constants captured for the reverse graph into it
  liftConstants(grad_desc, rev_info);
  // addReverseInline has to call gradientForNode if *any* of the outputs
  // require grad, but it will emit vjps for *all* outputs. Use DCE to remove
  // unnecessary nodes.
  EliminateDeadCode(rev_info.reverse_block);
  // Fills in f, df, f_real_outputs, df_input_captures,
  // modifies df_input_vjps (new vjps are added for temporaries)
  lambdaLiftReverse(grad_desc, rev_info);
  return grad_desc;
}

ちなみに、あたしは、微分プログラミングとか、auto.diff とか、全く分からないのですよ。
どうしたら、いいのでしょうか?