「PyTorch + XLA」のテストコード(test/test_operatopns.py) の TestMulAdd クラスでは、次のように、PyTorchのモデルを torch.jit.trace メソッドにて何かしている?
class TestMulAdd(XlaTestCase): def test(self): class XlaMulAdd(nn.Module): def forward(self, x, y): return x * y + y x = torch.rand(3, 5) y = torch.rand(3, 5) model = XlaMulAdd() traced_model = torch.jit.trace(model, (x, y)) xla_model = torch_xla._XLAC.XlaModule(traced_model) inputs_xla = [torch_xla._XLAC.XLATensor(x), torch_xla._XLAC.XLATensor(y)] output_xla = xla_model((tuple(inputs_xla))) expected = model(x, y) self.assertEqualDbg(output_xla[0][0].to_tensor().data, expected.data)
この torch.jit.trace って、いったい何をやっているのだろうか?
Google君に聞いたら、出てきました。。
ここに、ちゃんと説明されていました。こんな感じに。
import torch import torchvision # An instance of your model. model = torchvision.models.resnet18() # An example input you would normally provide to your model's forward() method. example = torch.rand(1, 3, 224, 224) # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing. traced_script_module = torch.jit.trace(model, example)
ここでもモデルは、import torchvision にて、torchvision.models.resnet18() を使っています。
このモデル (model) と入力データ (example = torch.rand(1, 3, 244, 244) を torch.jit.trace に渡して、
戻ってきたのが、traced_script_module で、このモデルが torch.jit.ScriptModule になるようです。
このモデル (model) と入力データ (example = torch.rand(1, 3, 244, 244) を torch.jit.trace に渡して、
戻ってきたのが、traced_script_module で、このモデルが torch.jit.ScriptModule になるようです。
output = traced_script_module(torch.ones(1, 3, 224, 224))
のように、入力データを指定して、その結果が戻ってくると。。。
PyTorchでは、下記のように、torch.nn.Module クラスを継承してモデルを構築するようですね。
__init__ メソッドと、forward メソッドを定義すればいいようです。
__init__ メソッド内では、モデルに必要な準備をして、forward メソッド内では入力データからの処理を行っています。
__init__ メソッドと、forward メソッドを定義すればいいようです。
__init__ メソッド内では、モデルに必要な準備をして、forward メソッド内では入力データからの処理を行っています。
import torch class MyModule(torch.nn.Module): def __init__(self, N, M): super(MyModule, self).__init__() self.weight = torch.nn.Parameter(torch.rand(N, M)) def forward(self, input): if input.sum() > 0: output = self.weight.mv(input) else: output = self.weight + input return output
このモデルを上記で説明した、torch.jit.ScriptModule に変えたものが下記のコードです。
torch.nn.Module から torch.jit.ScriptModule に変更して、
forward メソッドには @torch.jit.script_method というデコレータを追加しています。
torch.nn.Module から torch.jit.ScriptModule に変更して、
forward メソッドには @torch.jit.script_method というデコレータを追加しています。
import torch class MyModule(torch.jit.ScriptModule): def __init__(self, N, M): super(MyModule, self).__init__() self.weight = torch.nn.Parameter(torch.rand(N, M)) @torch.jit.script_method def forward(self, input): if input.sum() > 0: output = self.weight.mv(input) else: output = self.weight + input return output my_script_module = MyModule()
traced_script_module = MyModule() traced_script_module.save("model.pt")
model.pt というファイルに出力されます。このファイルを再度ロードすることができます。
C++コードでのシリアライズされたファイルのロード方法は次のようなコードでできるようです。
torch::jit::load関数にて、ファイル名を指定するだけでいいです。
戻り値は、std::shared_ptr<torch::jit::script::Module> になります。
C++コードでは、torch::jit::script::Module が Python の torch.jit.ScriptModule に対応するようですね。
C++コードでのシリアライズされたファイルのロード方法は次のようなコードでできるようです。
torch::jit::load関数にて、ファイル名を指定するだけでいいです。
戻り値は、std::shared_ptr<torch::jit::script::Module> になります。
C++コードでは、torch::jit::script::Module が Python の torch.jit.ScriptModule に対応するようですね。
#include <torch/script.h> // One-stop header. #include <iostream> #include <memory> int main(int argc, const char* argv[]) { if (argc != 2) { std::cerr << "usage: example-app <path-to-exported-script-module>\n"; return -1; } // Deserialize the ScriptModule from a file using torch::jit::load(). std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]); assert(module != nullptr); std::cout << "ok\n"; }
ロードしたモデルを C++コードで実行するコードもあります。
モデルに対して、torch::jit::IValue として、(1, 3, 224, 244) の要素がすべて1のデータを module->forward の
引数として渡して、その戻り値をtoTensor() で変換したものが最終的な戻り値 (at::Tensor) になっています。
モデルに対して、torch::jit::IValue として、(1, 3, 224, 244) の要素がすべて1のデータを module->forward の
引数として渡して、その戻り値をtoTensor() で変換したものが最終的な戻り値 (at::Tensor) になっています。
// Create a vector of inputs. std::vector<torch::jit::IValue> inputs; inputs.push_back(torch::ones({1, 3, 224, 224})); // Execute the model and turn its output into a tensor. at::Tensor output = module->forward(inputs).toTensor(); std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
「PyTorch + XLA」では、ここで、script::Module が使われています。
void InitXlaModuleBindings(py::module m) { py::class_<XlaModule, std::shared_ptr<XlaModule>>(m, "XlaModule") .def(py::init([](const std::shared_ptr<script::Module> module, bool use_full_conv_precision, bool differentiate) { return std::make_shared<XlaModule>(module, use_full_conv_precision, differentiate); }), py::arg("module"), py::arg("use_full_conv_precision") = false, py::arg("differentiate") = true)
script::Modulle となっているのは、その前に、
namespace torch { namespace jit { namespace {
になっていますので、torch::jit::script::Module になっているわけですね。