今回が最後。
モデルの実行部分。
def test_resnet(): # Load the model model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'resnet18') classes = 1000 device = 'cpu' model = DLRModel(model_path, device) # Run the model image = np.load(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'dog.npy')).astype(np.float32) #flatten within a input array input_data = {'data': image} probabilities = model.run(input_data) #need to be a list of input arrays matching input names assert probabilities[0].argmax() == 111
の
probabilities = model.run(input_data) #need to be a list of input arrays matching input namesの run メソッド。
def run(self, input_values): out = [] # set input(s) if isinstance(input_values, (np.ndarray, np.generic)): # Treelite model or single input tvm/treelite model. # Treelite has a dummy input name 'data'. if self.input_names: self._set_input(self.input_names[0], input_values) elif isinstance(input_values, dict): # TVM model for key, value in input_values.items(): if self.input_names and key not in self.input_names: raise ValueError("%s is not a valid input name." % key) self._set_input(key, value) else: raise ValueError("input_values must be of type dict (tvm model) " + "or a np.ndarray/generic (representing treelite models)") # run model self._run() # get output for i in range(self.num_outputs): ith_out = self._get_output(i) out.append(ith_out) return out
で、_run メソッドを呼んでいます。
def _run(self): """A light wrapper to call run in the DLR backend.""" self._check_call(self.lib.RunDLRModel(byref(self.handle)))
RunDLRModelメソッドを呼んでいますね。
extern "C" int RunDLRModel(DLRModelHandle *handle) { API_BEGIN(); static_cast<DLRModel *>(*handle)->Run(); API_END(); }
DLRModelクラス の Runメソッドを呼んでいます。
void DLRModel::Run() { if (backend_ == DLRBackend::kTVM) { // get the function from the module(run it) tvm::runtime::PackedFunc run = tvm_module_->GetFunction("run"); run(); } else if (backend_ == DLRBackend::kTREELITE) { // NOTE: Assume batch size is 1. However, Treelite internally can support // arbitrary batch size size_t out_result_size; CHECK_EQ(TreelitePredictorPredictInst(treelite_model_, treelite_input_.get(), 0, treelite_output_.get(), &out_result_size), 0) << TreeliteGetLastError(); } }
にて、TVMモデルの時は、GetFunctionメソッドにてrunメソッド呼んでいますね。
} else if (name == "run") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); }); }
最後は、GraphRuntime::Runメソッドを実行。
void GraphRuntime::Run() { // setup the array and requirements. for (size_t i = 0; i < op_execs_.size(); ++i) { if (op_execs_[i]) op_execs_[i](); } }ひとつづつ、オペを実行しているだけですね。