Vengineerの妄想

人生を妄想しています。

Google ColabでXLA in Python


Google/jaxのnotebooksに下記のファイルがアップされましたよ。


Google/jax では、TensorFlow XLAにPythonでアクセスするコードを独自に持っているのではなく、
TensorFlowの中のXLAになるPythonコードをちょっと変更して使っています。

jax/buildの install_xla_in_source_tree.sh というシェルスクリプトの最後で下記のようにTensorFlowのコードから
必要な部分だけ(tensorfow/compiler/xla/python/xla_client.py)を取り出して、ちょっとだけ変更しています。

sed \
  -e 's/from tensorflow.compiler.xla.python import pywrap_xla as c_api/from . import pywrap_xla as c_api/' \
  -e 's/from tensorflow.compiler.xla import xla_data_pb2/from . import xla_data_pb2/' \
  -e '/from tensorflow.compiler.xla.service import hlo_pb2/d' \
  < "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xla_client.py)" \
  > "${TARGET}/jaxlib/xla_client.py"

変更後のファイルを、jaxlib/xla_client.py にストアしています。

このファイルが XLA in Python のメイン部になります。

XLA in Pythonに書いてある内容は、Google Colabにて実行できます。ただし、TPUではなく、GPU(CUDA)ですが。。。

    $ pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-0.1.7-cp36-none-linux_x86_64.whl

from jaxlib import xla_client
とすることで、XLAへのアクセスができるようになり、

下記のようなコードにて、xla_clientのComputationBuilderクラス内の各メソッドを使ってXLAにアクセスすることになります。
# make a computation builder
c = xla_client.ComputationBuilder("simple_scalar")

# define a parameter shape and parameter
param_shape = xla_client.Shape.array_shape(onp.float32, ())
x = c.ParameterWithShape(param_shape)

# define computation graph
y = c.Sin(x)

# build computation graph
# Keep in mind that incorrectly constructed graphs can cause 
# your notebook kernel to crash!
computation = c.Build()

# compile graph based on shape
compiled_computation = computation.Compile([param_shape,])

# define a host variable with above parameter shape
host_input = onp.array(3.0, dtype=onp.float32)

# place host variable on device and execute
device_input = xla_client.LocalBuffer.from_pyval(host_input)
device_out = compiled_computation.Execute([device_input ,])

# retrive the result
device_out.to_py()

おっと、xla_client.py にJAX用のコードが入れられたのね。
# Version of the XLA Python client.
#
# JAX packages the XLA python plugin as a binary pip module (jaxlib) that is
# packaged separately from the Python code that consumes it (jax).
#
# We occasionally need to make backwards-incompatible changes to jaxlib, in
# which case we need to be able to detect when incompatible versions are
# installed.
def version():
  return (0, 1, 7)

このバージョンが
    $ pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-0.1.7-cp36-none-linux_x86_64.whl
の jaxlib-0.1.7 に対応するのね。