Vengineerの妄想

人生を妄想しています。

LOADING A PYTORCH MODEL IN C++


「PyTorch + XLA」のソースコード解析をして、
PyTorchで C++ のモデルの扱いについて、調べてみたら、おもろかったので記録として残しておきます。


「PyTorch + XLA」のテストコード(test/test_operatopns.py) の TestMulAdd クラスでは、次のように、PyTorchのモデルを torch.jit.trace メソッドにて何かしている?

class TestMulAdd(XlaTestCase):

  def test(self):

    class XlaMulAdd(nn.Module):

      def forward(self, x, y):
        return x * y + y

    x = torch.rand(3, 5)
    y = torch.rand(3, 5)
    model = XlaMulAdd()
    traced_model = torch.jit.trace(model, (x, y))
    xla_model = torch_xla._XLAC.XlaModule(traced_model)
    inputs_xla = [torch_xla._XLAC.XLATensor(x), torch_xla._XLAC.XLATensor(y)]
    output_xla = xla_model((tuple(inputs_xla)))
    expected = model(x, y)
    self.assertEqualDbg(output_xla[0][0].to_tensor().data, expected.data)

この torch.jit.trace って、いったい何をやっているのだろうか?

Google君に聞いたら、出てきました。。


ここに、ちゃんと説明されていました。こんな感じに。

import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

ここでもモデルは、import torchvision にて、torchvision.models.resnet18() を使っています。
このモデル (model) と入力データ (example = torch.rand(1, 3, 244, 244) を torch.jit.trace に渡して、
戻ってきたのが、traced_script_module で、このモデルが torch.jit.ScriptModule になるようです。

    output = traced_script_module(torch.ones(1, 3, 224, 224))

のように、入力データを指定して、その結果が戻ってくると。。。

PyTorchでは、下記のように、torch.nn.Module クラスを継承してモデルを構築するようですね。
__init__ メソッドと、forward メソッドを定義すればいいようです。
__init__ メソッド内では、モデルに必要な準備をして、forward メソッド内では入力データからの処理を行っています。

import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

このモデルを上記で説明した、torch.jit.ScriptModule に変えたものが下記のコードです。
torch.nn.Module から torch.jit.ScriptModule に変更して、
forward メソッドには @torch.jit.script_method というデコレータを追加しています。

import torch

class MyModule(torch.jit.ScriptModule):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    @torch.jit.script_method
    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

my_script_module = MyModule()

torch.jit.ScriptModule は、ファイルにシリアライズすることができます。

    traced_script_module = MyModule()

    traced_script_module.save("model.pt")

model.pt というファイルに出力されます。このファイルを再度ロードすることができます。
C++コードでのシリアライズされたファイルのロード方法は次のようなコードでできるようです。
torch::jit::load関数にて、ファイル名を指定するだけでいいです。
戻り値は、std::shared_ptr<torch::jit::script::Module> になります。
C++コードでは、torch::jit::script::Module が Python の torch.jit.ScriptModule に対応するようですね。

#include <torch/script.h> // One-stop header.

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }

  // Deserialize the ScriptModule from a file using torch::jit::load().
  std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);

  assert(module != nullptr);
  std::cout << "ok\n";
}

ロードしたモデルを C++コードで実行するコードもあります。
モデルに対して、torch::jit::IValue として、(1, 3, 224, 244) の要素がすべて1のデータを module->forward の
引数として渡して、その戻り値をtoTensor() で変換したものが最終的な戻り値 (at::Tensor) になっています。

// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));

// Execute the model and turn its output into a tensor.
at::Tensor output = module->forward(inputs).toTensor();

std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

「PyTorch + XLA」では、ここで、script::Module が使われています。

void InitXlaModuleBindings(py::module m) {
  py::class_<XlaModule, std::shared_ptr<XlaModule>>(m, "XlaModule")
      .def(py::init([](const std::shared_ptr<script::Module> module,
                       bool use_full_conv_precision, bool differentiate) {
             return std::make_shared<XlaModule>(module, use_full_conv_precision,
                                                differentiate);
           }),
           py::arg("module"), py::arg("use_full_conv_precision") = false,
           py::arg("differentiate") = true)

script::Modulle となっているのは、その前に、

namespace torch {
namespace jit {

namespace {

になっていますので、torch::jit::script::Module になっているわけですね。