Vengineerの妄想

人生を妄想しています。

Dynamically loadable XLA pluginの内容


先週金曜日(2018年4月6日)のDynamically loadable XLA plugin proposalの続きで、ソースコードの調べています。

TensorFlow XLAでは、下記のように、Platform、Compiler、Computation Placer、Transfer Manager、Deviceの登録が必要です。
  1)、Platformの登録 (Executorの登録含む)
  2)、Kernelの登録
  3)、Compilerの登録
  4)、Computation Placerの登録
  5)、Transfer Managerの登録
  6)、Deviceの登録

dynamically loadable XLA plugin でも同じことをしています。


このファイルの最後で、
引用
// The volatile is key here as otherwise the optimizer removes the call to
// InitPluginModule() and the function as well.
volatile bool module_initialized = InitPluginModule();

と、InitPluginModule 関数が呼ばれています。

この InitPluginModule 関数は、下記のようになっています。
bool InitPluginModule() {

  // We are running as part of TensorFlow python environment
  auto tf_root = xla::dynamic_plugin::GetTensorflowRoot();
  auto plugin_directory = tf_root + "/plugins/";
  // Get the list of plugin.so files
  std::string pattern = plugin_directory + "*.so";

  std::vector<std::string> files;
  auto result = tensorflow::Env::Default()->GetMatchingPaths(pattern, &files);
  if (!result.ok() || files.size() == 0) {
    VLOG(1) << "No dynamic XLA plugins found in: " << plugin_directory;
    return false;
  }

  // Load the first plugin for now as loading multiple plugis do not work
  // yet
  tensorflow::LoadDynamicPlugin(files[0]);
}

最初に、TensorFlowのルートディレクトリ名を xla::dynamic_plugin::GetTensorflowRoot() 関数で獲得し、
その下にある plugins ディレクトリにある共有ライブラリ(*.so)の中から最初に見つかったライブラリ(files[0])を
使って、tensorflow::LoadDynamicPlugin 関数を呼んでいます。

この tensorflow::LoadDynamicPlugin 関数の中で上で説明した登録を以下のように行っています。
 ・GetPluginData関数を外部のPlugin Dataを獲得する (=> ngrap tensorflow bridgeのライブラリから獲得する)

   xla::plugin::Info plugin_info = GetPluginData();

 ・GetPluginData関数で獲得したplugin_info構造体のDeviceInfoメンバーからPLUGINNAMEを獲得する

    auto device_info = DeviceInfo();

    VLOG(1) << "PLUGIN NAME: " << device_info.PLATFORM_NAME;

  ・Kernel登録

    REGISTER_XLA_LAUNCH_KERNEL(device_info.XLA_DEVICE_NAME, tensorflow::XlaLocalLaunchOp, supported_data_types);
    REGISTER_XLA_DEVICE_KERNELS(device_info.XLA_DEVICE_NAME, supported_data_types);
    REGISTER_XLA_BACKEND(device_info.XLA_DEVICE_JIT_NAME, supported_data_types, OpFilter);

  ・Platformの登録 (PlatformAdapter内で、ExecutorAdapterを初期化)

     std::unique_ptr<perftools::gputools::Platform> platform(
      new xla::dynamic_plugin::PlatformAdapter(
          device_info.PLATFORM_NAME, kPluginPlatformId,
          device_info.visible_device_count));

     perftools::gputools::MultiPlatformManager::RegisterPlatform(std::move(platform));

 ・Platformの初期化

   auto status = plugin_info.Init(kPluginPlatformId);
     if (!status) {
        LOG(WARNING) << "Plugin initialization failed";
        return false;
     }

 ・Compilerの登録

    xla::Compiler::RegisterCompilerFactory(kPluginPlatformId, [=]() {
      return xla::MakeUnique<xla::dynamic_plugin::CompilerAdapter>(
            kPluginPlatformId, plugin_info);
    });

 ・Computation Placerの登録

  xla::ComputationPlacer::RegisterComputationPlacer(
                           kPluginPlatformId,
                           &xla::dynamic_plugin::CompilerAdapter::CreateComputationPlacer);

 ・Transfer Managerの登録 (ちょっと長いです)

    xla::dynamic_plugin::TransferManagerAdapter::Init(kPluginPlatformId);

    xla::dynamic_plugin::TransferManagerAdapter* new_transfer_manager{nullptr};
    const perftools::gputools::Platform* this_platform;
    auto statusor = perftools::gputools::MultiPlatformManager::PlatformWithId(kPluginPlatformId);
    this_platform = statusor.ValueOrDie();

    xla::StatusOr<xla::TransferManager*> s = xla::TransferManager::GetForPlatform(this_platform);
    new_transfer_manager = (xla::dynamic_plugin::TransferManagerAdapter*)s.ValueOrDie();
  
    auto plugin_transfer_manager = GetTransferManager();

    new_transfer_manager->SetImplementation(plugin_transfer_manager);

 ・Deviceの登録

    DeviceFactory::Register( device_info.XLA_DEVICE_NAME,
                             new DeviceFactoryAdapter(device_info.PLATFORM_NAME,
                                                      device_info.XLA_DEVICE_NAME,
                                                      device_info.XLA_DEVICE_JIT_NAME),
                             device_info.device_priority);

というように、オリジナルのXLAと同じように各種登録していますね。