12月1日に、Qiitaの「TensorFlow Advent Calendar 2018」の初日(12/1)として、アップした「PyTorch + XLA」のソースコード解析ですが、みなさんへの「クリスマスプレゼント」として、残りの部分を今日(12/24)は、3章、4章、5章と、明日(12/25)は6章と7章を公開します。
今回公開するメモは、「PyTorch + XLA」が公開された時のコードです。
「PyTorch + XLA」はその後もアップデートされていて、コードは変わっていますので、注意してくださいね。
「PyTorch + XLA」はその後もアップデートされていて、コードは変わっていますので、注意してくださいね。
=======================================================================================
3)、PyTorchからXLAのOpへの変換
BuildComputationProgram メソッドの内で、グラフ内のノードに対して、変換を行っています。
残念ながら、下記の case分には無いタイプの Op は変換できません。
ResNet-50 を学習できるレベルのOpはサポートしているようです。
他の Op も変換したい場合は、ひたすら変換したい Op 用のコードを書きまくるしかないのだろうか?
ResNet-50 を学習できるレベルのOpはサポートしているようです。
他の Op も変換したい場合は、ひたすら変換したい Op 用のコードを書きまくるしかないのだろうか?
auto nodes = graph_->block()->nodes(); for (auto node : nodes) { switch (node->kind()) { case aten::add: case aten::mul: {...} case aten::gt: {...} case aten::type_as: {...} case aten::convolution: case aten::thnn_conv2d_forward: {...} case aten::thnn_conv2d_backward: {...} case aten::t: {...} case aten::addmm: {...} case aten::mm: {...} case aten::max_pool2d_with_indices: {...} case aten::max_pool2d_with_indices_backward: {...} case aten::avg_pool2d: {...} case aten::avg_pool2d_backward: {...} case aten::neg: {...} case aten::tanh: {...} case aten::sigmoid: {...} case aten::relu: {...} case aten::threshold: {...} case aten::threshold_backward: {...} case aten::log_softmax: {...} case aten::_log_softmax_backward_data: {...} case aten::reshape: case aten::view: {...} case aten::expand: {...} case aten::stack: {...} case aten::cat: {...} case aten::chunk: {...} case aten::native_batch_norm: case aten::batch_norm: {...} case aten::native_batch_norm_backward: {...} case aten::sum: {...} case aten::nll_loss: {...} case aten::nll_loss_backward: {...} case prim::Constant: {...} case prim::ListConstruct: {...} case prim::Undefined: {...}
具体的な例として、一番簡単そうな add/mul を見てみましょう。
case aten::add: case aten::mul: { const auto node_inputs = node->inputs(); if (node_inputs.size() < 2) { AT_ERROR("Unsupported arity for binary operator ", node->kind().toQualString()); } // ノードの入力を獲得して、入力が2未満の場合はエラーになります auto input_op_1_optional = cctx.GetOpForInput(node, 1); xla::XlaOp input_op_1; if (!input_op_1_optional) { input_op_1 = XlaHelpers::ScalarValue( node->get<at::Scalar>(attr::other).value().to<float>(), b); } else { input_op_1 = *input_op_1_optional; } auto inputs = XlaHelpers::PromoteValues(cctx.OpForInput(node, 0), input_op_1); xla::XlaOp xla_output = BuildArithmeticOp(node, inputs.first, inputs.second); // BuildArithmeticOpにて、nodeのタイプによって、Add/Mulを区別しています。 // BuildAddArithmeticOpは、xla/torch_xla/csrc/elementwise.cpp で // 定義されています。 cctx.AddNodeOp(node, xla_output); // AddNodeOpにて、node に、xla_output(XlaOp) を追加します // void AddNodeOp(const Node* node, xla::XlaOp op) { // AddNodeOpById(OutputId(node), op); // } // void AddNodeOpById(size_t id, xla::XlaOp op) { // const auto it_ok = node_xla_ops_.emplace(id, std::move(op)); // XLA_CHECK(it_ok.second) << "Duplicated IR node ID: " << id; // } break; }
BuildArithmeticOp メソッドでは、add と mul のケースに分岐して処理をしています。
xla::XlaOp BuildArithmeticOp(const Node* node, const xla::XlaOp& lhs, const xla::XlaOp& rhs) { // node->kind() にて、add と mul を区別しています。 switch (node->kind()) { case aten::add: { return DoBinaryOpWithImplicitExpand( lhs, rhs, [](const xla::XlaOp& lhs, const xla::XlaOp& rhs) { return lhs + rhs; }); } case aten::mul: { return DoBinaryOpWithImplicitExpand( lhs, rhs, [](const xla::XlaOp& lhs, const xla::XlaOp& rhs) { return lhs * rhs; }); } default: LOG(FATAL) << "Invalid binary operator kind: " << node->kind(); } } xla::XlaOp DoBinaryOpWithImplicitExpand( const xla::XlaOp& lhs, const xla::XlaOp& rhs, std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&)> op_fun) { return op_fun(BuildImplicitExpand(lhs, rhs), BuildImplicitExpand(rhs, lhs)); }
BuildImplicitExpand メソッドは、2つの引数のShapeSizesを求め、大きい方に拡張します。
xla::XlaOp BuildImplicitExpand(const xla::XlaOp& input, const xla::XlaOp& output) { const auto output_sizes = XlaHelpers::ShapeSizes(XlaHelpers::ShapeOfXlaOp(output)); const auto input_sizes = XlaHelpers::ShapeSizes(XlaHelpers::ShapeOfXlaOp(input)); if (input_sizes.size() >= output_sizes.size() || input_sizes.empty()) { return input; } return BuildExpandToOutputSizes(input, output_sizes); }
・batch_norm BatchNormTraining, GetTupleElement, Pow ・pooling MaxPooling, SelectAndScatterWithGeneralPadding, AvgPool, AvgPoolGrad ・convolution Transpose, Rev, Pad, ConvGeneralDilated, Broadcast, Reduce ・elementwise Gt, ConvertElementType, Broadcast, ConvertElementType ・reduction Reduce ・log_softmax Reduce, Sub, Mul, Log ・null_loss Select, ReduceAll, Neg ・cross_replica_reduces CrossReplicaSum, Broadcast
4)、一旦、振り返り
テストコード(TestMulAdd)がどのように実行されるかを振り返ったのが、次のようになります。
class TestMulAdd(XlaTestCase): def test(self): class XlaMulAdd(nn.Module): def forward(self, x, y): return x * y + y x = torch.rand(3, 5) y = torch.rand(3, 5) model = XlaMulAdd() traced_model = torch.jit.trace(model, (x, y)) xla_model = torch_xla._C.XlaModule(traced_model) inputs_xla = [torch_xla._C.XLATensor(x), torch_xla._C.XLATensor(y)] output_xla = xla_model((tuple(inputs_xla))) => .def("__call__", [](XlaModule& xla_module, py::args args) -> py::object { => XlaModule::TensorBatchVector XlaModule::forward( const TensorBatchVector& inputs) { Initialize(inputs); if (enable_trace_fusion_) { const auto return_node = df_->return_node(); const auto node_inputs = return_node->inputs(); if (!node_inputs.empty()) { return RunFusedTrain(inputs); => Execute => ExecuteComputation => XlaTranslator xla_fwd_impl xla_fwd_impl.BuildComputation => BuildComputationProgram ここで、ノードを XlaOp に変換する } } return RunUnfusedForward(inputs); => Execute => ExecuteComputation => XlaTranslator xla_fwd_impl xla_fwd_impl.BuildComputation( => BuildComputationProgram ここで、ノードを XlaOp に変換する expected = model(x, y) self.assertEqualDbg(output_xla[0][0].to_tensor().data, expected.data)
5)、実行環境の決定
XlaModule::Executeメソッドの中に出てきた CreateClient メソッドは、下記のようなコードです。
void CreateClient(std::unique_ptr<xla::ComputationClient>* client) { *client = std::move(xla::ComputationClient::Create().ValueOrDie()); } xla::ComputationClient::Create() メソッドは、次のようになっています。 StatusOr<std::unique_ptr<ComputationClient>> ComputationClient::Create() { std::unique_ptr<ComputationClient> client; string xrt_config_path; // XrtClientを使うかどうかをチェック if (ShouldUseXrtClient(&xrt_config_path)) {
bool ShouldUseXrtClient(string* config_path) { *config_path = GetTpuClusterConfigPath(); if (access(config_path->c_str(), F_OK) != -1) { // If we have a TPU cluster config file, we are in Cloud TPU world, so steer // towards config file based XRT client. // config fileがある場合 return true; } config_path->clear(); // 環境変数 XLA_USE_XRT が 1 以上の場合 // export XLA_USE_XRT=0 export XLA_GRPC_HOST="" XLA_PLATFORM="CPU" // の場合は、XLA_USE_XRT=0 なので、false になる。 // export XLA_USE_XRT=1 \ // XRT_DEVICE_MAP="CPU:0;/job:localhost/replica:0/task:0/device:XLA_CPU:0" \ // XRT_WORKERS="localhost:0;" // の場合は、XLA_USE_XRT=1 なので、true になる。 return GetEnvInt("XLA_USE_XRT", -1) > 0; }
// XrtComputationClient を利用する XrtComputationClient::Options options; if (!xrt_config_path.empty()) { LOG(INFO) << "Loading XRT configuration from " << xrt_config_path; TF_RETURN_IF_ERROR(ParseTpuClusterConfig(xrt_config_path, &options)); } else { TF_ASSIGN_OR_RETURN(bool configured, ParseEnvBasedTpuClusterConfig(&options));
StatusOr<bool> ParseEnvBasedTpuClusterConfig( XrtComputationClient::Options* options) { string tpu_config = GetEnvString("XRT_TPU_CONFIG", ""); if (tpu_config.empty()) { return false; // 環境変数 XRT_TPU_CONFIG が設定されていなければ、false を返す } std::map<string, int> device_ordinals; std::vector<string> spec_parts = absl::StrSplit(tpu_config, '|'); TF_RET_CHECK(!spec_parts.empty()) << tpu_config; for (const auto& spec : spec_parts) { std::vector<string> host_parts = absl::StrSplit(spec, ';'); TF_RET_CHECK(host_parts.size() == 3) << spec; AddXrtHostDevices(host_parts[0], std::stoi(host_parts[1]), host_parts[2], &device_ordinals, options); } options->default_device = "TPU:0"; return true; }
if (!configured) { // 環境変数 XRT_TPU_CONFIG が設定されていない string device_spec = GetEnvString("XRT_DEVICE_MAP", "TPU:0;/job:tpu_worker/replica:0/task:0/device:TPU:0"); // 環境変数 XRT_DEVICEMAP を解析する for (const auto& device_target : absl::StrSplit(device_spec, '|')) { std::vector<string> parts = absl::StrSplit(device_target, ';'); TF_RET_CHECK(parts.size() == 2) << device_target; if (options.default_device.empty()) { options.default_device = parts[0]; } options.device_map.emplace(parts[0], parts[1]); } string workers_spec = GetEnvString("XRT_WORKERS", "tpu_worker:0;grpc://localhost:51000"); // 環境変数 XRT_WORKERS を解析する for (const auto& name_target : absl::StrSplit(workers_spec, '|')) { std::vector<string> parts = absl::StrSplit(name_target, ';'); TF_RET_CHECK(parts.size() == 2); options.workers_map.emplace(ParseWorker(parts[0]), parts[1]); } // export XLA_USE_XRT=1 \ // XRT_DEVICE_MAP="TPU:0;/job:tpu_worker/replica:0/task:0/device:TPU:0"\ // XRT_WORKERS="tpu_worker:0;grpc://localhost:51000" } } client.reset(new XrtComputationClient(options)); } else { XlaComputationClient::Options options; options.host_name = GetEnvString("XLA_GRPC_HOST", "localhost"); options.port = GetEnvInt("XLA_GRPC_PORT", 51000); options.platform = GetEnvString("XLA_PLATFORM", "TPU"); client.reset(new XlaComputationClient(options)); } return std::move(client);