Vengineerの妄想(準備期間)

人生は短いけど、長いです。人生を楽しみましょう!

Dynamically loadable XLA pluginの内容


先週金曜日(2018年4月6日)のDynamically loadable XLA plugin proposalの続きで、ソースコードの調べています。

TensorFlow XLAでは、下記のように、Platform、Compiler、Computation Placer、Transfer Manager、Deviceの登録が必要です。
  1)、Platformの登録 (Executorの登録含む)
  2)、Kernelの登録
  3)、Compilerの登録
  4)、Computation Placerの登録
  5)、Transfer Managerの登録
  6)、Deviceの登録

dynamically loadable XLA plugin でも同じことをしています。


このファイルの最後で、
引用
// The volatile is key here as otherwise the optimizer removes the call to
// InitPluginModule() and the function as well.
volatile bool module_initialized = InitPluginModule();

と、InitPluginModule 関数が呼ばれています。

この InitPluginModule 関数は、下記のようになっています。
bool InitPluginModule() {

  // We are running as part of TensorFlow python environment
  auto tf_root = xla::dynamic_plugin::GetTensorflowRoot();
  auto plugin_directory = tf_root + "/plugins/";
  // Get the list of plugin.so files
  std::string pattern = plugin_directory + "*.so";

  std::vector<std::string> files;
  auto result = tensorflow::Env::Default()->GetMatchingPaths(pattern, &files);
  if (!result.ok() || files.size() == 0) {
    VLOG(1) << "No dynamic XLA plugins found in: " << plugin_directory;
    return false;
  }

  // Load the first plugin for now as loading multiple plugis do not work
  // yet
  tensorflow::LoadDynamicPlugin(files[0]);
}

最初に、TensorFlowのルートディレクトリ名を xla::dynamic_plugin::GetTensorflowRoot() 関数で獲得し、
その下にある plugins ディレクトリにある共有ライブラリ(*.so)の中から最初に見つかったライブラリ(files[0])を
使って、tensorflow::LoadDynamicPlugin 関数を呼んでいます。

この tensorflow::LoadDynamicPlugin 関数の中で上で説明した登録を以下のように行っています。
 ・GetPluginData関数を外部のPlugin Dataを獲得する (=> ngrap tensorflow bridgeのライブラリから獲得する)

   xla::plugin::Info plugin_info = GetPluginData();

 ・GetPluginData関数で獲得したplugin_info構造体のDeviceInfoメンバーからPLUGINNAMEを獲得する

    auto device_info = DeviceInfo();

    VLOG(1) << "PLUGIN NAME: " << device_info.PLATFORM_NAME;

  ・Kernel登録

    REGISTER_XLA_LAUNCH_KERNEL(device_info.XLA_DEVICE_NAME, tensorflow::XlaLocalLaunchOp, supported_data_types);
    REGISTER_XLA_DEVICE_KERNELS(device_info.XLA_DEVICE_NAME, supported_data_types);
    REGISTER_XLA_BACKEND(device_info.XLA_DEVICE_JIT_NAME, supported_data_types, OpFilter);

  ・Platformの登録 (PlatformAdapter内で、ExecutorAdapterを初期化)

     std::unique_ptr<perftools::gputools::Platform> platform(
      new xla::dynamic_plugin::PlatformAdapter(
          device_info.PLATFORM_NAME, kPluginPlatformId,
          device_info.visible_device_count));

     perftools::gputools::MultiPlatformManager::RegisterPlatform(std::move(platform));

 ・Platformの初期化

   auto status = plugin_info.Init(kPluginPlatformId);
     if (!status) {
        LOG(WARNING) << "Plugin initialization failed";
        return false;
     }

 ・Compilerの登録

    xla::Compiler::RegisterCompilerFactory(kPluginPlatformId, [=]() {
      return xla::MakeUnique<xla::dynamic_plugin::CompilerAdapter>(
            kPluginPlatformId, plugin_info);
    });

 ・Computation Placerの登録

  xla::ComputationPlacer::RegisterComputationPlacer(
                           kPluginPlatformId,
                           &xla::dynamic_plugin::CompilerAdapter::CreateComputationPlacer);

 ・Transfer Managerの登録 (ちょっと長いです)

    xla::dynamic_plugin::TransferManagerAdapter::Init(kPluginPlatformId);

    xla::dynamic_plugin::TransferManagerAdapter* new_transfer_manager{nullptr};
    const perftools::gputools::Platform* this_platform;
    auto statusor = perftools::gputools::MultiPlatformManager::PlatformWithId(kPluginPlatformId);
    this_platform = statusor.ValueOrDie();

    xla::StatusOr<xla::TransferManager*> s = xla::TransferManager::GetForPlatform(this_platform);
    new_transfer_manager = (xla::dynamic_plugin::TransferManagerAdapter*)s.ValueOrDie();
  
    auto plugin_transfer_manager = GetTransferManager();

    new_transfer_manager->SetImplementation(plugin_transfer_manager);

 ・Deviceの登録

    DeviceFactory::Register( device_info.XLA_DEVICE_NAME,
                             new DeviceFactoryAdapter(device_info.PLATFORM_NAME,
                                                      device_info.XLA_DEVICE_NAME,
                                                      device_info.XLA_DEVICE_JIT_NAME),
                             device_info.device_priority);

というように、オリジナルのXLAと同じように各種登録していますね。