Vengineerの戯言

人生は短いけど、長いです。人生を楽しみましょう!

Google JAX と TensorFlow の XRT


このブログに、Google JAXの記事を書いたのが、2018年3月15日のJAX

その実装コードも既に公開されています。JAX: Autograd and XLA

前回のブログでは、バックエンドに、XLA を使うと、書きましたが、

公開されたソースコードを眺めていたら、jax/lib/xla_bridge.py には、

flags.DEFINE_enum(
    'jax_xla_backend', 'xla', ['xla', 'xrt'],
    'Either "xla" for the XLA service directly, or "xrt" for an XRT backend.')
とか
def _get_xla_client(backend_name, platform_name, replica_count):
  """Configures and returns a handle to the XLA client.
  Args:
    backend_name: backend name, 'xla' or 'xrt'
    platform_name: platform name for XLA backend
    replica_count: number of computation replicas with which to configure the
      backend library.
  Returns:
    A client library module, or an object that behaves identically to one.
  """
  global _platform_name
  xla_client.initialize_replica_count(replica_count)
  if backend_name == 'xla':
    if platform_name:
      xla_client.initialize_platform_name(platform_name)
      _platform_name = platform_name
    else:
      try:
        xla_client.initialize_platform_name('CUDA')
        _platform_name = 'CUDA'
      except RuntimeError:
        warnings.warn('No GPU found, falling back to CPU.')
        xla_client.initialize_platform_name('Host')
        _platform_name = 'Host'
  return xla_client
とかありますね。

バックエンドには、XLA だけでなく、XRT も使えるようです。

xla_client って、from jaxlib import xla_client で、jaxlib からの import になっていて、
jaxlib ディレクトリには、xla_client.py なるファイルはありません。


build/install_xla_in_source_tree.sh
sed \
  -e 's/from tensorflow.compiler.xla.python import pywrap_xla as c_api/from . import pywrap_xla as c_api/' \
  -e 's/from tensorflow.compiler.xla import xla_data_pb2/from . import xla_data_pb2/' \
  -e '/from tensorflow.compiler.xla.service import hlo_pb2/d' \
  < "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xla_client.py)" \
  > "${TARGET}/jaxlib/xla_client.py"
から、TensorFlow XLA 内の python/xla_client.py の内容を変更したものですね。

このコードは、2018年3月1日のブログ、TensorFlow r1.6 & local Python XLA clientに書いたものです。
Slideshareにもアップしました。TensorFlow local Python XLA client
local Python XLA client を使っても利用できるOpの種類はかなり限定的です。

そこで、JAXでもXLAだけでなく、XRTを利用するようになったのでしょうかね。どうなんでしょうかね。

ということで、2018年12月31日、最後のブログは、TensorFlow XLA 関連記事にしておきました。