昨日は、LOADING A PYTORCH MODEL IN C++として、PyTorchのtorch.jit.ScriptModule を C++コードで使うためにはどうすればいいかを見てみました。
今日は、せっかくなので、Python側のTorchScriptについても記録のために残しておきます。
下記のコードが torch.jit.ScriptModule クラスを使って定義したクラスの例。
import torch class MyModule(torch.jit.ScriptModule): def __init__(self, N, M): super(MyModule, self).__init__() self.weight = torch.nn.Parameter(torch.rand(N, M)) @torch.jit.script_method def forward(self, input): return self.weight.mv(input)
もう一つの例がこちら。
from torch.jit import ScriptModule, script_method になっているので、
クラス名は、torch.jit.ScriptModule から ScriptModule に、
@torch.jit.script_method から @script_method になっています。
クラス名は、torch.jit.ScriptModule から ScriptModule に、
@torch.jit.script_method から @script_method になっています。
import torch import torch.nn as nn import torch.nn.functional as F from torch.jit import ScriptModule, script_method, trace class MyScriptModule(ScriptModule): def __init__(self): super(MyScriptModule, self).__init__() # trace produces a ScriptModule's conv1 and conv2 self.conv1 = trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16)) self.conv2 = trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16)) @script_method def forward(self, input): input = F.relu(self.conv1(input)) input = F.relu(self.conv2(input)) return input
シリアライズしたモデルをファイルにストアするには、
save(filename)
シリアライズしたモデルをファイルからロードするには、
torch.jit.load(f, map_location=None)
モデルのロードは、下記のようにすれば、OK。。。
>>> torch.jit.load('scriptmodule.pt') # Load ScriptModule from io.BytesIO object >>> with open('scriptmodule.pt', 'rb') as f: buffer = io.BytesIO(f.read()) # Load all tensors to the original device >>> torch.jit.load(buffer) # Load all tensors onto CPU, using a device >>> torch.jit.load(buffer, map_location=torch.device('cpu')) # Load all tensors onto CPU, using a string >>> torch.jit.load(buffer, map_location='cpu')
torch.nn.Module や 関数(func) から torch.jit.ScriptModule に変換するのが、torch.jit.trace のようです。
第一引数が torch.nn.Module または func で、第二引数が 入力データ。それ以降はオプション。
第一引数が torch.nn.Module または func で、第二引数が 入力データ。それ以降はオプション。
torch.jit.trace(func, example_inputs, optimize=True, check_trace=True, check_inputs=None, check_tolerance=1e-05, _force_outplace=False)