Vengineerの妄想

人生を妄想しています。

TorchScript


昨日は、LOADING A PYTORCH MODEL IN C++として、PyTorchのtorch.jit.ScriptModule を C++コードで使うためにはどうすればいいかを見てみました。

今日は、せっかくなので、Python側のTorchScriptについても記録のために残しておきます。


下記のコードが torch.jit.ScriptModule クラスを使って定義したクラスの例。

import torch
class MyModule(torch.jit.ScriptModule):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    @torch.jit.script_method
    def forward(self, input):
        return self.weight.mv(input)

もう一つの例がこちら。

from torch.jit import ScriptModule, script_method になっているので、
クラス名は、torch.jit.ScriptModule から ScriptModule に、
@torch.jit.script_method から @script_method になっています。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import ScriptModule, script_method, trace

class MyScriptModule(ScriptModule):
    def __init__(self):
        super(MyScriptModule, self).__init__()
        # trace produces a ScriptModule's conv1 and conv2
        self.conv1 = trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    @script_method
    def forward(self, input):
      input = F.relu(self.conv1(input))
      input = F.relu(self.conv2(input))
      return input

シリアライズしたモデルをファイルにストアするには、

    save(filename)

シリアライズしたモデルをファイルからロードするには、

    torch.jit.load(f, map_location=None)

fは、ファイル名。map_locationは、デバイスを示す文字列('cpu’, ‘cuda:0')、またはデバイス名(torch.device(‘cpu’))

モデルのロードは、下記のようにすれば、OK。。。
>>> torch.jit.load('scriptmodule.pt')
# Load ScriptModule from io.BytesIO object
>>> with open('scriptmodule.pt', 'rb') as f:
        buffer = io.BytesIO(f.read())
# Load all tensors to the original device
>>> torch.jit.load(buffer)
# Load all tensors onto CPU, using a device
>>> torch.jit.load(buffer, map_location=torch.device('cpu'))
# Load all tensors onto CPU, using a string
>>> torch.jit.load(buffer, map_location='cpu')

torch.nn.Module や 関数(func) から torch.jit.ScriptModule に変換するのが、torch.jit.trace のようです。
第一引数が torch.nn.Module または func で、第二引数が 入力データ。それ以降はオプション。

    torch.jit.trace(func, 
                    example_inputs, 
                    optimize=True, 
                    check_trace=True, 
                    check_inputs=None, 
                    check_tolerance=1e-05, 
                    _force_outplace=False)