Vengineerの戯言

人生は短いけど、長いです。人生を楽しみましょう!

「PyTorch + XLA」のソースコード解析 (その2)



12月1日に、Qiitaの「TensorFlow Advent Calendar 2018」の初日(12/1)として、アップした「PyTorch + XLA」のソースコード解析ですが、みなさんへの「クリスマスプレゼント」として、残りの部分を今日(12/24)は、3章、4章、5章と、明日(12/25)は6章と7章を公開します。

今回公開するメモは、「PyTorch + XLA」が公開された時のコードです。
「PyTorch + XLA」はその後もアップデートされていて、コードは変わっていますので、注意してくださいね。

=======================================================================================

3)、PyTorchからXLAのOpへの変換

BuildComputationProgram メソッドの内で、グラフ内のノードに対して、変換を行っています。

残念ながら、下記の case分には無いタイプの Op は変換できません。
ResNet-50 を学習できるレベルのOpはサポートしているようです。
他の Op も変換したい場合は、ひたすら変換したい Op 用のコードを書きまくるしかないのだろうか?

 auto nodes = graph_->block()->nodes();
  for (auto node : nodes) {
    switch (node->kind()) {
      case aten::add:
      case aten::mul: {...}
      case aten::gt: {...}
      case aten::type_as: {...}
      case aten::convolution:
      case aten::thnn_conv2d_forward: {...}
      case aten::thnn_conv2d_backward: {...}
      case aten::t: {...}
      case aten::addmm: {...}
      case aten::mm: {...}
      case aten::max_pool2d_with_indices: {...}
      case aten::max_pool2d_with_indices_backward: {...}
      case aten::avg_pool2d: {...}
      case aten::avg_pool2d_backward: {...}
      case aten::neg: {...}
      case aten::tanh: {...}
      case aten::sigmoid: {...}
      case aten::relu: {...}
      case aten::threshold: {...}
      case aten::threshold_backward: {...}
      case aten::log_softmax: {...}
      case aten::_log_softmax_backward_data: {...}
      case aten::reshape:
      case aten::view: {...}
      case aten::expand: {...}
      case aten::stack: {...}
      case aten::cat: {...}
      case aten::chunk: {...}
      case aten::native_batch_norm:
      case aten::batch_norm: {...}
      case aten::native_batch_norm_backward: {...}
      case aten::sum: {...}
      case aten::nll_loss: {...}
      case aten::nll_loss_backward: {...}
      case prim::Constant: {...}
      case prim::ListConstruct: {...}
      case prim::Undefined: {...}


具体的な例として、一番簡単そうな add/mul を見てみましょう。

      case aten::add:
      case aten::mul: {
        const auto node_inputs = node->inputs();
        if (node_inputs.size() < 2) {
          AT_ERROR("Unsupported arity for binary operator ",
                   node->kind().toQualString());
        }
        // ノードの入力を獲得して、入力が2未満の場合はエラーになります
        auto input_op_1_optional = cctx.GetOpForInput(node, 1);
        xla::XlaOp input_op_1;
        if (!input_op_1_optional) {
          input_op_1 = XlaHelpers::ScalarValue(
              node->get<at::Scalar>(attr::other).value().to<float>(), b);
        } else {
          input_op_1 = *input_op_1_optional;
        }
        auto inputs =
            XlaHelpers::PromoteValues(cctx.OpForInput(node, 0), input_op_1);
        xla::XlaOp xla_output =
            BuildArithmeticOp(node, inputs.first, inputs.second);
     // BuildArithmeticOpにて、nodeのタイプによって、Add/Mulを区別しています。
        // BuildAddArithmeticOpは、xla/torch_xla/csrc/elementwise.cpp で
        // 定義されています。
        cctx.AddNodeOp(node, xla_output);
        // AddNodeOpにて、node に、xla_output(XlaOp) を追加します   
        // void AddNodeOp(const Node* node, xla::XlaOp op) {
        //   AddNodeOpById(OutputId(node), op);
        // } 
      // void AddNodeOpById(size_t id, xla::XlaOp op) {
        //   const auto it_ok = node_xla_ops_.emplace(id, std::move(op));
        //   XLA_CHECK(it_ok.second) << "Duplicated IR node ID: " << id;
        // }
        break;
      }

BuildArithmeticOp メソッドでは、add と mul のケースに分岐して処理をしています。

xla::XlaOp BuildArithmeticOp(const Node* node, const xla::XlaOp& lhs,
                             const xla::XlaOp& rhs) {
  // node->kind() にて、add と mul を区別しています。
  switch (node->kind()) {
    case aten::add: {
      return DoBinaryOpWithImplicitExpand(
          lhs, rhs, [](const xla::XlaOp& lhs, const xla::XlaOp& rhs) {
            return lhs + rhs;
          });
    }
    case aten::mul: {
      return DoBinaryOpWithImplicitExpand(
          lhs, rhs, [](const xla::XlaOp& lhs, const xla::XlaOp& rhs) {
            return lhs * rhs;
          });
    }
    default:
      LOG(FATAL) << "Invalid binary operator kind: " << node->kind();
  }
}

xla::XlaOp DoBinaryOpWithImplicitExpand(
    const xla::XlaOp& lhs, const xla::XlaOp& rhs,
    std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&)> op_fun) {
  return op_fun(BuildImplicitExpand(lhs, rhs), BuildImplicitExpand(rhs, lhs));
}

BuildImplicitExpand メソッドは、2つの引数のShapeSizesを求め、大きい方に拡張します。

xla::XlaOp BuildImplicitExpand(const xla::XlaOp& input,
                               const xla::XlaOp& output) {
  const auto output_sizes =
      XlaHelpers::ShapeSizes(XlaHelpers::ShapeOfXlaOp(output));
  const auto input_sizes =
      XlaHelpers::ShapeSizes(XlaHelpers::ShapeOfXlaOp(input));
  if (input_sizes.size() >= output_sizes.size() || input_sizes.empty()) {
    return input;
  }
  return BuildExpandToOutputSizes(input, output_sizes);
}

各ノードを XlaOp に変換しているのは、xla/torch_xla/csrc/ ディレクトリにある以下のファイルです。
各ファイルの中で、xla::XXXXX に変換されます。

  ・batch_norm			BatchNormTraining, GetTupleElement, Pow
  ・pooling			MaxPooling, SelectAndScatterWithGeneralPadding, AvgPool, AvgPoolGrad
  ・convolution			Transpose, Rev, Pad, ConvGeneralDilated, Broadcast, Reduce
  ・elementwise 		Gt, ConvertElementType, Broadcast, ConvertElementType
  ・reduction			Reduce
  ・log_softmax			Reduce, Sub, Mul, Log
  ・null_loss			Select, ReduceAll, Neg
  ・cross_replica_reduces	CrossReplicaSum, Broadcast


4)、一旦、振り返り

テストコード(TestMulAdd)がどのように実行されるかを振り返ったのが、次のようになります。

class TestMulAdd(XlaTestCase):
    def test(self):

        class XlaMulAdd(nn.Module):
            def forward(self, x, y):
                return x * y + y

        x = torch.rand(3, 5)
        y = torch.rand(3, 5)
        model = XlaMulAdd()
        traced_model = torch.jit.trace(model, (x, y))
        xla_model = torch_xla._C.XlaModule(traced_model)
        inputs_xla = [torch_xla._C.XLATensor(x), torch_xla._C.XLATensor(y)]
        output_xla = xla_model((tuple(inputs_xla)))
          =>  .def("__call__",
               [](XlaModule& xla_module, py::args args) -> py::object {
                   =>  XlaModule::TensorBatchVector XlaModule::forward(
                                  const TensorBatchVector& inputs) {
                         Initialize(inputs);
                         if (enable_trace_fusion_) {
                           const auto return_node = df_->return_node();
                           const auto node_inputs = return_node->inputs();
                           if (!node_inputs.empty()) {
                             return RunFusedTrain(inputs);
                               => Execute
                     => ExecuteComputation
                                   => XlaTranslator xla_fwd_impl 
                                     xla_fwd_impl.BuildComputation
                                     => BuildComputationProgram
                                        ここで、ノードを XlaOp に変換する
                           }
                         }
                         return RunUnfusedForward(inputs);
                           => Execute 
                   => ExecuteComputation
                               => XlaTranslator xla_fwd_impl
                                 xla_fwd_impl.BuildComputation(
                                 => BuildComputationProgram
                      ここで、ノードを XlaOp に変換する
        expected = model(x, y)
        self.assertEqualDbg(output_xla[0][0].to_tensor().data, expected.data)



5)、実行環境の決定

XlaModule::Executeメソッドの中に出てきた CreateClient メソッドは、下記のようなコードです。

     void CreateClient(std::unique_ptr<xla::ComputationClient>* client) {
       *client = std::move(xla::ComputationClient::Create().ValueOrDie());
     }

 xla::ComputationClient::Create() メソッドは、次のようになっています。

StatusOr<std::unique_ptr<ComputationClient>> ComputationClient::Create() {
  std::unique_ptr<ComputationClient> client;
  string xrt_config_path;
 // XrtClientを使うかどうかをチェック
  if (ShouldUseXrtClient(&xrt_config_path)) {

bool ShouldUseXrtClient(string* config_path) {
  *config_path = GetTpuClusterConfigPath();
  if (access(config_path->c_str(), F_OK) != -1) {
    // If we have a TPU cluster config file, we are in Cloud TPU world, so steer
    // towards config file based XRT client.
    // config fileがある場合
    return true;
  }
  config_path->clear();
  // 環境変数 XLA_USE_XRT が 1 以上の場合
  // export XLA_USE_XRT=0 export XLA_GRPC_HOST="" XLA_PLATFORM="CPU"
  // の場合は、XLA_USE_XRT=0 なので、false になる。
  // export XLA_USE_XRT=1 \
  //  XRT_DEVICE_MAP="CPU:0;/job:localhost/replica:0/task:0/device:XLA_CPU:0" \   
  //  XRT_WORKERS="localhost:0;"
  // の場合は、XLA_USE_XRT=1 なので、true になる。
  return GetEnvInt("XLA_USE_XRT", -1) > 0;
}

    // XrtComputationClient を利用する
    XrtComputationClient::Options options;
    if (!xrt_config_path.empty()) {
      LOG(INFO) << "Loading XRT configuration from " << xrt_config_path;
      TF_RETURN_IF_ERROR(ParseTpuClusterConfig(xrt_config_path, &options));
    } else {
      TF_ASSIGN_OR_RETURN(bool configured,
                          ParseEnvBasedTpuClusterConfig(&options));

StatusOr<bool> ParseEnvBasedTpuClusterConfig(
    XrtComputationClient::Options* options) {
  string tpu_config = GetEnvString("XRT_TPU_CONFIG", "");
  if (tpu_config.empty()) {
    return false;
    // 環境変数 XRT_TPU_CONFIG が設定されていなければ、false を返す
  }
  std::map<string, int> device_ordinals;
  std::vector<string> spec_parts = absl::StrSplit(tpu_config, '|');
  TF_RET_CHECK(!spec_parts.empty()) << tpu_config;
  for (const auto& spec : spec_parts) {
    std::vector<string> host_parts = absl::StrSplit(spec, ';');
    TF_RET_CHECK(host_parts.size() == 3) << spec;
    AddXrtHostDevices(host_parts[0], std::stoi(host_parts[1]), host_parts[2],
                      &device_ordinals, options);
  }
  options->default_device = "TPU:0";
  return true;
}

      if (!configured) {
      // 環境変数 XRT_TPU_CONFIG が設定されていない
        string device_spec =
            GetEnvString("XRT_DEVICE_MAP",
                         "TPU:0;/job:tpu_worker/replica:0/task:0/device:TPU:0");
        // 環境変数 XRT_DEVICEMAP を解析する
        for (const auto& device_target : absl::StrSplit(device_spec, '|')) {
          std::vector<string> parts = absl::StrSplit(device_target, ';');
          TF_RET_CHECK(parts.size() == 2) << device_target;
          if (options.default_device.empty()) {
            options.default_device = parts[0];
          }
          options.device_map.emplace(parts[0], parts[1]);
        }
        string workers_spec =
            GetEnvString("XRT_WORKERS", "tpu_worker:0;grpc://localhost:51000");
        // 環境変数 XRT_WORKERS を解析する
        for (const auto& name_target : absl::StrSplit(workers_spec, '|')) {
          std::vector<string> parts = absl::StrSplit(name_target, ';');
          TF_RET_CHECK(parts.size() == 2);
          options.workers_map.emplace(ParseWorker(parts[0]), parts[1]);
        }
      // export XLA_USE_XRT=1 \
      // XRT_DEVICE_MAP="TPU:0;/job:tpu_worker/replica:0/task:0/device:TPU:0"\
      // XRT_WORKERS="tpu_worker:0;grpc://localhost:51000"
      }
    }
    client.reset(new XrtComputationClient(options));
  } else {
    XlaComputationClient::Options options;
    options.host_name = GetEnvString("XLA_GRPC_HOST", "localhost");
    options.port = GetEnvInt("XLA_GRPC_PORT", 51000);
    options.platform = GetEnvString("XLA_PLATFORM", "TPU");
    client.reset(new XlaComputationClient(options));
  }
  return std::move(client);