Vengineerの妄想(準備期間)

人生は短いけど、長いです。人生を楽しみましょう!

AWS Neo-AI DLRのソースコード解析(その4)


今回が最後。

モデルの実行部分。

def test_resnet():
    # Load the model
    model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
            'resnet18')
    classes = 1000
    device = 'cpu'
    model = DLRModel(model_path, device)

    # Run the model
    image = np.load(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'dog.npy')).astype(np.float32)
    #flatten within a input array
    input_data = {'data': image}
    probabilities = model.run(input_data) #need to be a list of input arrays matching input names
    assert probabilities[0].argmax() == 111


    probabilities = model.run(input_data) #need to be a list of input arrays matching input names
run メソッド

    def run(self, input_values):
        out = []
        # set input(s)
        if isinstance(input_values, (np.ndarray, np.generic)):
            # Treelite model or single input tvm/treelite model.
            # Treelite has a dummy input name 'data'.
            if self.input_names:
                self._set_input(self.input_names[0], input_values)
        elif isinstance(input_values, dict):
            # TVM model
            for key, value in input_values.items():
                if self.input_names and key not in self.input_names:
                    raise ValueError("%s is not a valid input name." % key)
                self._set_input(key, value)
        else:
            raise ValueError("input_values must be of type dict (tvm model) " +
                             "or a np.ndarray/generic (representing treelite models)")
        # run model
        self._run()
        # get output
        for i in range(self.num_outputs):
            ith_out = self._get_output(i)
            out.append(ith_out)
        return out

で、_run メソッドを呼んでいます。

    def _run(self):
        """A light wrapper to call run in the DLR backend."""
        self._check_call(self.lib.RunDLRModel(byref(self.handle)))

RunDLRModelメソッドを呼んでいますね。

extern "C" int RunDLRModel(DLRModelHandle *handle) {
  API_BEGIN();
  static_cast<DLRModel *>(*handle)->Run();
  API_END();
}

DLRModelクラス の Runメソッドを呼んでいます。

void DLRModel::Run() {
  if (backend_ == DLRBackend::kTVM) {
    // get the function from the module(run it)
    tvm::runtime::PackedFunc run = tvm_module_->GetFunction("run");
    run();
  } else if (backend_ == DLRBackend::kTREELITE) {
    // NOTE: Assume batch size is 1. However, Treelite internally can support
    //       arbitrary batch size
    size_t out_result_size;
    CHECK_EQ(TreelitePredictorPredictInst(treelite_model_, treelite_input_.get(),
                                          0, treelite_output_.get(),
                                          &out_result_size), 0)
      << TreeliteGetLastError();
  }
}

にて、TVMモデルの時は、GetFunctionメソッドにてrunメソッド呼んでいますね。
  } else if (name == "run") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        this->Run();
      });
  }

最後は、GraphRuntime::Runメソッドを実行。
void GraphRuntime::Run() {
  // setup the array and requirements.
  for (size_t i = 0; i < op_execs_.size(); ++i) {
    if (op_execs_[i]) op_execs_[i]();
  }
}
ひとつづつ、オペを実行しているだけですね。