TensorFlow XLAでは、下記のように、Platform、Compiler、Computation Placer、Transfer Manager、Deviceの登録が必要です。
1)、Platformの登録 (Executorの登録含む) 2)、Kernelの登録 3)、Compilerの登録 4)、Computation Placerの登録 5)、Transfer Managerの登録 6)、Deviceの登録
dynamically loadable XLA plugin でも同じことをしています。
このファイルの最後で、
引用 // The volatile is key here as otherwise the optimizer removes the call to // InitPluginModule() and the function as well. volatile bool module_initialized = InitPluginModule();
と、InitPluginModule 関数が呼ばれています。
この InitPluginModule 関数は、下記のようになっています。
bool InitPluginModule() { // We are running as part of TensorFlow python environment auto tf_root = xla::dynamic_plugin::GetTensorflowRoot(); auto plugin_directory = tf_root + "/plugins/"; // Get the list of plugin.so files std::string pattern = plugin_directory + "*.so"; std::vector<std::string> files; auto result = tensorflow::Env::Default()->GetMatchingPaths(pattern, &files); if (!result.ok() || files.size() == 0) { VLOG(1) << "No dynamic XLA plugins found in: " << plugin_directory; return false; } // Load the first plugin for now as loading multiple plugis do not work // yet tensorflow::LoadDynamicPlugin(files[0]); }
最初に、TensorFlowのルートディレクトリ名を xla::dynamic_plugin::GetTensorflowRoot() 関数で獲得し、
その下にある plugins ディレクトリにある共有ライブラリ(*.so)の中から最初に見つかったライブラリ(files[0])を
使って、tensorflow::LoadDynamicPlugin 関数を呼んでいます。
その下にある plugins ディレクトリにある共有ライブラリ(*.so)の中から最初に見つかったライブラリ(files[0])を
使って、tensorflow::LoadDynamicPlugin 関数を呼んでいます。
この tensorflow::LoadDynamicPlugin 関数の中で上で説明した登録を以下のように行っています。
・GetPluginData関数を外部のPlugin Dataを獲得する (=> ngrap tensorflow bridgeのライブラリから獲得する) xla::plugin::Info plugin_info = GetPluginData(); ・GetPluginData関数で獲得したplugin_info構造体のDeviceInfoメンバーからPLUGINNAMEを獲得する auto device_info = DeviceInfo(); VLOG(1) << "PLUGIN NAME: " << device_info.PLATFORM_NAME; ・Kernel登録 REGISTER_XLA_LAUNCH_KERNEL(device_info.XLA_DEVICE_NAME, tensorflow::XlaLocalLaunchOp, supported_data_types); REGISTER_XLA_DEVICE_KERNELS(device_info.XLA_DEVICE_NAME, supported_data_types); REGISTER_XLA_BACKEND(device_info.XLA_DEVICE_JIT_NAME, supported_data_types, OpFilter); ・Platformの登録 (PlatformAdapter内で、ExecutorAdapterを初期化) std::unique_ptr<perftools::gputools::Platform> platform( new xla::dynamic_plugin::PlatformAdapter( device_info.PLATFORM_NAME, kPluginPlatformId, device_info.visible_device_count)); perftools::gputools::MultiPlatformManager::RegisterPlatform(std::move(platform)); ・Platformの初期化 auto status = plugin_info.Init(kPluginPlatformId); if (!status) { LOG(WARNING) << "Plugin initialization failed"; return false; } ・Compilerの登録 xla::Compiler::RegisterCompilerFactory(kPluginPlatformId, [=]() { return xla::MakeUnique<xla::dynamic_plugin::CompilerAdapter>( kPluginPlatformId, plugin_info); }); ・Computation Placerの登録 xla::ComputationPlacer::RegisterComputationPlacer( kPluginPlatformId, &xla::dynamic_plugin::CompilerAdapter::CreateComputationPlacer); ・Transfer Managerの登録 (ちょっと長いです) xla::dynamic_plugin::TransferManagerAdapter::Init(kPluginPlatformId); xla::dynamic_plugin::TransferManagerAdapter* new_transfer_manager{nullptr}; const perftools::gputools::Platform* this_platform; auto statusor = perftools::gputools::MultiPlatformManager::PlatformWithId(kPluginPlatformId); this_platform = statusor.ValueOrDie(); xla::StatusOr<xla::TransferManager*> s = xla::TransferManager::GetForPlatform(this_platform); new_transfer_manager = (xla::dynamic_plugin::TransferManagerAdapter*)s.ValueOrDie(); auto plugin_transfer_manager = GetTransferManager(); new_transfer_manager->SetImplementation(plugin_transfer_manager); ・Deviceの登録 DeviceFactory::Register( device_info.XLA_DEVICE_NAME, new DeviceFactoryAdapter(device_info.PLATFORM_NAME, device_info.XLA_DEVICE_NAME, device_info.XLA_DEVICE_JIT_NAME), device_info.device_priority);
というように、オリジナルのXLAと同じように各種登録していますね。