Vengineerの妄想

人生を妄想しています。

AWS Neo-AI DLRのソースコード解析(その3)



TVMバックエンドの時は、SetupTVMModuleにて、TVMモデルを実行するための環境を構築しています。

void DLRModel::SetupTVMModule(const std::string& model_path) {
  ModelPath paths = get_tvm_paths(model_path);
  std::ifstream jstream(paths.model_json);
  std::stringstream json_blob;
  json_blob << jstream.rdbuf();
  std::ifstream pstream(paths.params);
  std::stringstream param_blob;
  param_blob << pstream.rdbuf();

  tvm::runtime::Module module;
  if (!IsFileEmpty(paths.model_lib)){
    module = tvm::runtime::Module::LoadFromFile(paths.model_lib);
  }
  tvm_graph_runtime_ =
    std::make_shared<tvm::runtime::GraphRuntime>();
  tvm_graph_runtime_->Init(json_blob.str(), module, {ctx_});
  tvm_graph_runtime_->LoadParams(param_blob.str());

  tvm_module_ = std::make_shared<tvm::runtime::Module>(
      tvm::runtime::Module(tvm_graph_runtime_));

  // Save the number of inputs. It excludes inputs that could be obtained
  // through the param file, such as weights.
  num_inputs_ = tvm_graph_runtime_->NumInputs() - GetWeightNames().size();
  std::vector<std::string> input_names;
  for (int i = 0; i < num_inputs_; i++)  {
    input_names.push_back(tvm_graph_runtime_->GetInputName(i));
  }
  std::vector<std::string> weight_names = tvm_graph_runtime_->GetWeightNames();
  std::set_difference(input_names.begin(), input_names.end(),
                      weight_names.begin(), weight_names.end(),
                      std::inserter(input_names_, input_names_.begin()));

  // Get the number of output and reserve space to save output tensor
  // pointers.
  num_outputs_ = tvm_graph_runtime_->NumOutputs();
    outputs_.resize(num_outputs_);
  for (int i = 0; i < num_outputs_; i++) {
    tvm::runtime::NDArray output = tvm_graph_runtime_->GetOutput(i);
    outputs_[i] = output.operator->();
  }
}


  tvm::runtime::Module module;
  if (!IsFileEmpty(paths.model_lib)){
    module = tvm::runtime::Module::LoadFromFile(paths.model_lib);
  }
  tvm_graph_runtime_ =
    std::make_shared<tvm::runtime::GraphRuntime>();
  tvm_graph_runtime_->Init(json_blob.str(), module, {ctx_});

がモデルファイルからロード(Module::LoadFromFile)して、ランタイムの初期化( Init)をしています。