昨日、[ 「PyTorch + XLA」のソースコード解析 (その2)]の続きで、6章と7章。。
===================================================================================
6)、ComputationClientでの実行
XlaModule::Executeメソッド内で、
以下のように、xla::ComputationClient の ExecutionComputation メソッド か ExecuteReplicated が実行されます。
以下のように、xla::ComputationClient の ExecutionComputation メソッド か ExecuteReplicated が実行されます。
auto client = XlaGetClient(); std::vector<std::shared_ptr<xla::ComputationClient::Data>> exec_results; if (inputs.size() == 1) { exec_results.push_back( client->ExecuteComputation(computation, inputs.front(), &result_shape)); } else { exec_results = client->ExecuteReplicated(computation, inputs, &result_shape); } return DecomposeComputationResult(std::move(exec_results), result_shape, module_id);
XlaGetClient が返す client の値は、「5)、実行環境の決定」 で説明しました XlaComputationClient と XrtComputationClient の2つのケースです。それぞれのケースについて、見ていきたいと思います。
6.1)、XlaComputationClientのケース
XlaComputationClient の場合では、XLAのデバイス(CPUとTPU、GPUが利用できるかは確認できていません)を利用します。XlaComputaionClientでは、client_->Execute を実行します。
std::shared_ptr<ComputationClient::Data> XlaComputationClient::ExecuteComputation( const XlaComputation& computation, tensorflow::gtl::ArraySlice<Data*> arguments, const Shape* output_shape) { metrics::TimedSection timed(ExecuteMetric()); FlushReleasedHandles(); string device; // 入力データを <GlobalData> に変換します std::vector<GlobalData*> arguments_data = GetArgumentsData(arguments, &device); ExecutionOptions eo; *eo.mutable_debug_options() = legacy_flags::GetDebugOptionsFromFlags(); *eo.add_device_handles() = GetDeviceHandle(device); if (output_shape != nullptr) { *eo.mutable_shape_with_output_layout() = *output_shape; } // この部分で、XLA 経由でデバイス上で実行します StatusOr<std::unique_ptr<GlobalData>> result_or_status = client_->Execute(computation, arguments_data, &eo); xrt_util::CheckComputationStatus(result_or_status.status(), {&computation}); ProgramShape program_shape; if (output_shape == nullptr) { program_shape = computation.GetProgramShape().ValueOrDie(); output_shape = &program_shape.result(); } // Executeの戻り値を XlaData に変換します return std::make_shared<XlaData>( std::move(result_or_status.ValueOrDie()), device, *output_shape, [this](XlaData* xla_data) { ReleaseXlaData(xla_data); }); }
TPUを利用する場合は、
export XLA_USE_XRT=0 export XLA_GRPC_HOST="" XLA_PLATFORM="CPU"
を ComputationClient::Create() メソッド の
XlaComputationClient::Options options; options.host_name = GetEnvString("XLA_GRPC_HOST", "localhost"); options.port = GetEnvInt("XLA_GRPC_PORT", 51000); options.platform = GetEnvString("XLA_PLATFORM", "TPU"); client.reset(new XlaComputationClient(options));
ExecuteReplicated メソッド は、まだ、サポートされています。
std::vector<std::shared_ptr<ComputationClient::Data>> XlaComputationClient::ExecuteReplicated( const XlaComputation& computation, const std::vector<std::vector<Data*>>& arguments, const Shape* output_shape) { metrics::TimedSection timed(ExecuteReplicatedMetric()); LOG(FATAL) << "ExecuteReplicated() API not yet implemented!"; }
6.2)、XrtComputationClientのケース
XrtComputationClient の場合では、CPUとTPUが利用できます。session->session.Run で実行されます。
std::shared_ptr<ComputationClient::Data> XrtComputationClient::ExecuteComputation( const XlaComputation& computation, tensorflow::gtl::ArraySlice<Data*> arguments, const Shape* output_shape) { metrics::TimedSection timed(ExecuteMetric()); ApiCallInitialize(); std::vector<string> devices; tensorflow::ClientSession::FeedType feed_inputs; std::vector<ExecuteContext> exec_ops = CreateExecuteOps(computation, BuildParallelArguments(arguments), output_shape, &devices, &feed_inputs); SessionData* session = GetSessionForDevice(devices.front()); std::vector<tensorflow::Tensor> outputs; TF_CHECK_OK(session->root.status()); xrt_util::CheckComputationStatus( session->session.Run(feed_inputs, {exec_ops.front().execute_output}, &outputs), {&computation}); XLA_CHECK_EQ(outputs.size(), 1); return std::make_shared<XrtData>( devices.front(), outputs[0].scalar<int64>()(), exec_ops.front().result_shape, [this](XrtData* xrt_data) { ReleaseXrtData(xrt_data); }); }
XrtComputationClient (TPU)では、ExecuteReplicated メソッドをサポートしています。
std::vector<std::shared_ptr<ComputationClient::Data>> XrtComputationClient::ExecuteReplicated( const XlaComputation& computation, const std::vector<std::vector<Data*>>& arguments, const Shape* output_shape) { metrics::TimedSection timed(ExecuteReplicatedMetric()); ApiCallInitialize(); std::vector<string> devices; tensorflow::ClientSession::FeedType feed_inputs; std::vector<ExecuteContext> exec_ops = CreateExecuteOps( computation, arguments, output_shape, &devices, &feed_inputs); return RunComputations(exec_ops, {&computation}, devices, feed_inputs); }
CreateExecuteOpsメソッドでは、std::vector<XrtComputationClient::ExecuteContext> を生成します。
生成された std::vector<XrtComputationClient::ExecuteContext> は、RunComputations メソッドにて、デバイスの数、Replicaを作って実行するようです。つまり、実行環境にあわせて Replica を勝手に作ってくれるようです。なので、TPUがたくさんあると、それなりに速くなるのでしょうかね。。。
生成された std::vector<XrtComputationClient::ExecuteContext> は、RunComputations メソッドにて、デバイスの数、Replicaを作って実行するようです。つまり、実行環境にあわせて Replica を勝手に作ってくれるようです。なので、TPUがたくさんあると、それなりに速くなるのでしょうかね。。。
std::vector<XrtComputationClient::ExecuteContext> XrtComputationClient::CreateExecuteOps( const XlaComputation& computation, const std::vector<std::vector<Data*>>& arguments, const Shape* output_shape, std::vector<string>* devices, tensorflow::ClientSession::FeedType* feed_inputs) { ProgramShape program_shape; if (output_shape == nullptr) { program_shape = computation.GetProgramShape().ValueOrDie(); output_shape = &program_shape.result(); } *devices = GetReplicasDevices(arguments); auto xrt_computation = CreateXrtComputation(computation, arguments.size(), *devices, output_shape); absl::optional<tensorflow::ops::Placeholder> computation_holder; std::vector<ExecuteContext> exec_ops; for (size_t i = 0; i < arguments.size(); ++i) { auto inputs = GetArgumentsInputs(arguments[i], devices->at(i), feed_inputs); const string& xrt_device = TorchDeviceToXrtDevice(devices->at(i)); // この部分が、XRT に関する部分になるようです。 SessionData* session = GetSessionForXrtDevice(xrt_device); tensorflow::Scope device_scope = session->root.WithDevice(xrt_device); const CachedNode& cached_node = GetCompileExecuteNode(device_scope, devices->at(i)); // GetCompileExecuteNode メソッドにて、 // XRTCompile Op から XRTExecute Op につながるようになっています。 // ここが、XRT のポイントなんでしょうか?
const XrtComputationClient::CachedNode& XrtComputationClient::GetCompileExecuteNode(const tensorflow::Scope& scope, const string& device) { NodeCache* cache = &node_cache_[NodeCacheKey(device, NodeTypes::kCompileExecute)]; if (cache->empty()) { std::vector<tensorflow::ops::Placeholder> holders( {tensorflow::ops::Placeholder(scope, tensorflow::DT_STRING), tensorflow::ops::Placeholder(scope, tensorflow::DT_STRING), tensorflow::ops::Placeholder( scope, tensorflow::DT_INT64, tensorflow::ops::Placeholder::Shape({-1}))}); auto computation_handle = tensorflow::ops::XRTCompile(scope, holders[0]); std::unique_ptr<CachedNode> node(new CachedNode( tensorflow::ops::XRTExecute(scope, computation_handle.handle, holders[1], {tensorflow::Output(holders[2])}), std::move(holders))); cache->add(std::move(node)); } return cache->get(); }
feed_inputs->insert( {cached_node.holders[0], xrt_computation->SerializeAsString()}); xrt::XRTExecutionConfig exec_config; exec_config.set_core_index_in_replica(0); exec_config.set_release_input_handles(false); exec_config.set_release_compilation_handle(true); feed_inputs->insert( {cached_node.holders[1], exec_config.SerializeAsString()}); feed_inputs->insert({cached_node.holders[2], inputs}); exec_ops.emplace_back(*cached_node.output, *output_shape); } return exec_ops; }
CreateXrtComputation メソッドは、xla::XlaComputation から xrt::XlaComputation を生成します。
std::unique_ptr<xrt::XLAComputation> XrtComputationClient::CreateXrtComputation( const XlaComputation& computation, int64 num_replicas, const std::vector<string>& devices, const Shape* output_shape) const { XLA_CHECK_EQ(num_replicas, devices.size()); std::unique_ptr<xrt::XLAComputation> xrt_computation( new xrt::XLAComputation()); auto config = xrt_computation->mutable_config(); config->set_num_replicas(num_replicas); config->set_num_cores_per_replica(1); if (num_replicas > 1) { auto device_assignment = config->mutable_device_assignment(); auto computation_device = device_assignment->add_computation_devices(); for (int64 i = 0; i < num_replicas; ++i) { const string& xrt_device = TorchDeviceToXrtDevice(devices[i]); const auto& core_coords = GetDeviceMeshCoords(xrt_device); auto replica_device = computation_device->add_replica_devices(); for (auto coord : core_coords) { replica_device->add_value(coord); } } } *config->mutable_program_shape() = computation.GetProgramShape().ValueOrDie(); if (output_shape != nullptr) { *config->mutable_program_shape()->mutable_result() = *output_shape; } *xrt_computation->mutable_hlo_snapshot() = *computation.Snapshot().ValueOrDie(); return xrt_computation; }
CreateExecuteOpsメソッドの戻り値である exec_ops を RunComputationsメソッドに渡します。sess_replica.first->session.Run で実行されます。
std::vector<std::shared_ptr<ComputationClient::Data>> XrtComputationClient::RunComputations( const std::vector<ExecuteContext>& exec_ops, tensorflow::gtl::ArraySlice<const XlaComputation* const> computations, const std::vector<string>& devices, const tensorflow::ClientSession::FeedType& feed_inputs) { // 「PyTorch + XLA」のメインとなる部分だと思います。 // この部分が XLA ではなく、 XRT を使うメリットなのでしょうかね。 // In the PyTorch/XRT interface we keep a map (options_.workers_map) from a // worker+taskno, to the GRPC server which is the entry point for that worker. // Since XRT could re-distribute ops internally, if we have N hosts // (worker+taskno), we could have all the workers pointing to a single GRPC // entry point, or we could have each worker pointing directly to the target // host. // The advantage of the latter approach, is that we do not bottleneck // (especially when feeding inputs) the single GRPC entry point. // Using the N:1 approach, the session_replicas below will contain a single // session, and all the replica executions will go through it (and distributed // by XRT on the service side). // Chosing the 1:1 approach (one session per worker), we will have N sessions // within the session_replicas map, which we will be executing independently. std::map<SessionData*, std::vector<size_t>> session_replicas; for (size_t i = 0; i < devices.size(); ++i) { SessionData* session = GetSessionForDevice(devices[i]); session_replicas[session].push_back(i); } // TODO(dlibenzi): These could be run in parallel. // devices.size() 分、Replicas を作って実行している // この部分がスケールすると、嬉しいということですね。 std::vector<std::shared_ptr<Data>> results(devices.size()); for (auto& sess_replica : session_replicas) { std::vector<tensorflow::Output> exec_nodes; for (auto replica : sess_replica.second) { exec_nodes.push_back(exec_ops[replica].execute_output); } std::vector<tensorflow::Tensor> outputs; TF_CHECK_OK(sess_replica.first->root.status()); xrt_util::CheckComputationStatus( sess_replica.first->session.Run(feed_inputs, exec_nodes, &outputs), computations); // ここで、session.Run を実行する XLA_CHECK_EQ(outputs.size(), exec_nodes.size()); for (size_t i = 0; i < outputs.size(); ++i) { auto replica = sess_replica.second[i]; results[replica] = std::make_shared<XrtData>( devices[replica], outputs[i].scalar<int64>()(), exec_ops[replica].result_shape, [this](XrtData* xrt_data) { ReleaseXrtData(xrt_data); }); } } return results; }
簡単にまとめると、
・XlaComputation の時 (XLA_USE_XRT=0)は、session->session.Run を実行
・XrtComputation の時 (XLA_USE_XRT=1)は、sess_replica.first->session.Run を実行
・XrtComputation の時 (XLA_USE_XRT=1)は、sess_replica.first->session.Run を実行
Clound TPU の場合は、TPUの数分、replica を作って、並列処理を行える!
7)、backward メソッド
テストコード内の test_nll_loss というテストでは、backward メソッドを呼んでいます。
また、XXXGrad(例えば、AvgPoolGrad クラス)と名前のクラスでも同じように、backward メソッドが呼ばれています。
また、XXXGrad(例えば、AvgPoolGrad クラス)と名前のクラスでも同じように、backward メソッドが呼ばれています。
def test_nll_loss(self): input = torch.randn(3, 5, requires_grad=True) target = torch.empty(3, dtype=torch.long).random_(5) model = XlaNllLoss() traced_model = torch.jit.trace(model, (input, target)) xla_model = torch_xla._C.XlaModule(traced_model) xla_inputs = [torch_xla._C.XLATensor(input), torch_xla._C.XLATensor(target)] output_xla = xla_model((tuple(xla_inputs))) xla_model.backward(*output_xla) output = model(input, target) output.backward() self.assertEqual(input.grad.data, xla_inputs[0].grad.data.to_tensor())
forward メソッドについては、これまでに説明してきましたが、backward メソッドについては、この章で説明します。backward メソッドは、forward メソッドに比べて非常に長いです。backward メソッドは、InitXlaModuleBindingsメソッドの中のXlaModuleクラスのメソッド ( backward ) として定義されています。
.def("backward", [](XlaModule& xla_module, py::args args) { auto inputs = XlaCreateTensorList(args); xla_module.backward(inputs); })
backwardメソッドでは、C++のXlaModuleの backward メソッドを呼んでいます。
forward側のXlaComputation (forward_computation)の生成は、 グラフ (f_) をベースに行います。
XlaTranslator xla_fwd_impl(f_, GetPrecisionConfig()); forward_computation_ = xla_fwd_impl.BuildComputation(forward_shapes);
一方、backward側のXlaComputation (back_computation)の生成は、 グラフ (df_) をベースに行います。
XlaTranslator xla_bwd_impl(df_, GetPrecisionConfig()); backward_computation_ = xla_bwd_impl.BuildComputation( backward_shapes, GetBackwardBuildOptions(0, inputs_.size()));
forward と backward では違うグラフ ( f_ と df_ )にて、XlaComputationを生成しています。では、このグラフ ( f_ と df_ )はどのように求めたのでしょうか? グラフ ( f_ と df_ )は、XlaModule::Initialize メソッドにて、次のように設定されています。
void XlaModule::Initialize(const TensorBatchVector& inputs) { if (script_module_ == nullptr) { return; } // Get forward graph. const auto forward = script_module_->find_method("forward"); // どうやら、”forward”メソッドを見つけているようだ JIT_ASSERT(forward); std::shared_ptr<Graph> forward_graph = forward->graph()->copy(); auto forward_graph_copy = forward_graph->copy(); // 見つけたものからグラフをゲットする Gradient gradient = differentiate(forward_graph_copy); // その微分を求める // 途中略 // Take ownership of the forward and differentiated graphs and release the // reference to the script module to mark initialization as done. f_ = gradient.f; df_ = gradient.df; // 求めた微分からグラフ ( f と df ) を獲得している // Mark the module as initialized. script_module_ = nullptr; }
forwardメソッドを見つけるためのscript_module_ は、コンストラクタにて module が設定されています。
XlaModule::XlaModule(const std::shared_ptr<script::Module> module, bool use_full_conv_precision, bool differentiate) : use_full_conv_precision_(use_full_conv_precision), enable_trace_fusion_(differentiate), differentiate_(differentiate), module_id_(s_module_id_++), script_module_(module) {}
テストコードの torch_xla._C.XlaModule の引数である traced_model が script_module_ になります。
class TestMulAdd(XlaTestCase): def test(self): class XlaMulAdd(nn.Module): def forward(self, x, y): return x * y + y x = torch.rand(3, 5) y = torch.rand(3, 5) model = XlaMulAdd() traced_model = torch.jit.trace(model, (x, y)) xla_model = torch_xla._C.XlaModule(traced_model) inputs_xla = [torch_xla._C.XLATensor(x), torch_xla._C.XLATensor(y)] output_xla = xla_model((tuple(inputs_xla))) expected = model(x, y) self.assertEqualDbg(output_xla[0][0].to_tensor().data, expected.data)
const auto forward = script_module_->find_method("forward"
は、上記のテストコードの TestMulAdd クラスでは、XlaMulAdd クラスの forward メソッドになります。std::shared_ptr<Graph> forward_graph = forward->graph()->copy(); auto forward_graph_copy = forward_graph->copy(); Gradient gradient = differentiate(forward_graph_copy);
は、foward メソッドのグラフから diffrerentiate メソッドにて、Gradient を求めていることになります。
differentiate メソッドは、PyTorchの autodiff.cpp の中で次のように定義されています。
Gradient differentiate(std::shared_ptr<Graph>& graph) { Gradient grad_desc; // Take ownership of the graph JIT_ASSERTM(graph.use_count() == 1, "differentiate will mutate and destroy the graph, so it requires " "graph.use_count() == 1, but found %d", graph.use_count()); std::swap(graph, grad_desc.f); // XXX: Take care when handling outputs - they can be duplicated! WithInsertPoint guard(grad_desc.f->block()); // Fills in df_input_vjps and df_output_vjps auto rev_info = addReverseInline(grad_desc); // Lift constants captured for the reverse graph into it liftConstants(grad_desc, rev_info); // addReverseInline has to call gradientForNode if *any* of the outputs // require grad, but it will emit vjps for *all* outputs. Use DCE to remove // unnecessary nodes. EliminateDeadCode(rev_info.reverse_block); // Fills in f, df, f_real_outputs, df_input_captures, // modifies df_input_vjps (new vjps are added for temporaries) lambdaLiftReverse(grad_desc, rev_info); return grad_desc; }