まずは、テストコード。
`load_and_run_tvm_model.py` の中にある ResNet の例。
def test_resnet():
    """Run the ResNet-18 model on a sample dog image and check the top-1 class.

    Both the model path ``resnet18`` and the input array ``dog.npy`` are
    resolved relative to the directory containing this test file.
    """
    # Directory of this test file; the model and the sample input live here.
    base_dir = os.path.dirname(os.path.abspath(__file__))

    # Load the model on the CPU.
    model_path = os.path.join(base_dir, 'resnet18')
    device = 'cpu'
    model = DLRModel(model_path, device)

    # Run the model. The input is passed as a dict that maps the model's
    # input name ('data') to its array; the image is not flattened.
    image = np.load(os.path.join(base_dir, 'dog.npy')).astype(np.float32)
    input_data = {'data': image}
    probabilities = model.run(input_data)

    # NOTE(review): probabilities[0] is presumably the first (only) model
    # output, i.e. the class scores; class 111 is the expected top-1 label.
    assert probabilities[0].argmax() == 111
昨日の例、

    input_shape = {'data': [1, 3, 224, 224]}  # A single RGB 224x224 image
    output_shape = [1, 1000]  # The probability for each one of the 1,000 classes
    device = 'cpu'  # Go, Raspberry Pi, go!
    model = DLRModel('resnet50', input_shape, output_shape, device)

とは微妙に違うね。モデルの指定方法が「モデル名＋入力シェイプ＋出力シェイプ」から「モデルファイルのパス」に変わっている。
DLRModel のコンストラクタ（`__init__`）の定義はここ。
def __init__(self, tar_path, dev_type='cpu', dev_id=0):
    """Load a compiled model through the native ``libdlr.so`` library.

    Args:
        tar_path: Path to the compiled model; it must exist on disk.
        dev_type: Device name, one of 'cpu', 'gpu' or 'opencl'.
        dev_id: Index of the device of that type.

    Raises:
        ValueError: If ``tar_path`` does not exist.
        KeyError: If ``dev_type`` is not a key of the device table below.
    """
    if not os.path.exists(tar_path):
        raise ValueError("tar_path %s doesn't exist" % tar_path)
    # Opaque model handle, filled in by the CreateDLRModel C call below.
    self.handle = c_void_p()
    # libdlr.so is loaded from the same directory as this Python module.
    libpath = os.path.join(os.path.dirname(
        os.path.abspath(os.path.expanduser(__file__))), 'libdlr.so')
    self.lib = cdll.LoadLibrary(libpath)
    # Make ctypes return the error message as a C string (bytes) rather than int.
    self.lib.DLRGetLastError.restype = ctypes.c_char_p
    # Device name -> integer code; the C side casts this code to DLDeviceType.
    device_table = {
        'cpu': 1,
        'gpu': 2,
        'opencl': 4,
    }
    self._check_call(self.lib.CreateDLRModel(byref(self.handle),
                                             c_char_p(tar_path.encode()),
                                             c_int(device_table[dev_type]),
                                             c_int(dev_id)))
    # Query and cache the model's input names.
    self.num_inputs = self._get_num_inputs()
    self.input_names = []
    for i in range(self.num_inputs):
        self.input_names.append(self._get_input_name(i))
    # Query and cache the model's output shapes.
    self.num_outputs = self._get_num_outputs()
    self.output_shapes = []
    # NOTE(review): output_size_dim is never filled in this excerpt; the full
    # upstream source may populate it after this point — confirm.
    self.output_size_dim = []
    for i in range(self.num_outputs):
        shape = self._get_output_shape(i)
        self.output_shapes.append(shape)
`self.lib.CreateDLRModel` の実体は、`dlr.cc` のここ。
// C entry point: build a DLRModel for the model at `model_path` on the device
// described by (`dev_type`, `dev_id`) and store the new object in `*handle`.
// The heap-allocated model is handed back through `*handle`; it is not freed here.
extern "C" int CreateDLRModel(DLRModelHandle* handle, const char* model_path,
                              int dev_type, int dev_id) {
  API_BEGIN();
  // Translate the raw integer device codes received over the C ABI.
  DLContext context;
  context.device_type = static_cast<DLDeviceType>(dev_type);
  context.device_id = dev_id;
  const std::string path(model_path);
  *handle = new DLRModel(path, context);
  API_END();
}
C++ 側の DLRModel クラスのインスタンスを `new` で生成し、そのポインタをハンドルとして返しているだけですよね。
明日に続く。