Google/jaxのnotebooksに下記のファイルがアップされましたよ。
Google/jax では、TensorFlow XLAにPythonでアクセスするコードを独自に持っているのではなく、
TensorFlowの中のXLAになるPythonコードをちょっと変更して使っています。
TensorFlowの中のXLAになるPythonコードをちょっと変更して使っています。
jax/buildの install_xla_in_source_tree.sh というシェルスクリプトの最後で下記のようにTensorFlowのコードから
必要な部分だけ(tensorfow/compiler/xla/python/xla_client.py)を取り出して、ちょっとだけ変更しています。
必要な部分だけ(tensorfow/compiler/xla/python/xla_client.py)を取り出して、ちょっとだけ変更しています。
sed \ -e 's/from tensorflow.compiler.xla.python import pywrap_xla as c_api/from . import pywrap_xla as c_api/' \ -e 's/from tensorflow.compiler.xla import xla_data_pb2/from . import xla_data_pb2/' \ -e '/from tensorflow.compiler.xla.service import hlo_pb2/d' \ < "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xla_client.py)" \ > "${TARGET}/jaxlib/xla_client.py"
変更後のファイルを、jaxlib/xla_client.py にストアしています。
このファイルが XLA in Python のメイン部になります。
$ pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-0.1.7-cp36-none-linux_x86_64.whl
from jaxlib import xla_clientとすることで、XLAへのアクセスができるようになり、
下記のようなコードにて、xla_clientのComputationBuilderクラス内の各メソッドを使ってXLAにアクセスすることになります。
# make a computation builder c = xla_client.ComputationBuilder("simple_scalar") # define a parameter shape and parameter param_shape = xla_client.Shape.array_shape(onp.float32, ()) x = c.ParameterWithShape(param_shape) # define computation graph y = c.Sin(x) # build computation graph # Keep in mind that incorrectly constructed graphs can cause # your notebook kernel to crash! computation = c.Build() # compile graph based on shape compiled_computation = computation.Compile([param_shape,]) # define a host variable with above parameter shape host_input = onp.array(3.0, dtype=onp.float32) # place host variable on device and execute device_input = xla_client.LocalBuffer.from_pyval(host_input) device_out = compiled_computation.Execute([device_input ,]) # retrive the result device_out.to_py()
おっと、xla_client.py にJAX用のコードが入れられたのね。
# Version of the XLA Python client. # # JAX packages the XLA python plugin as a binary pip module (jaxlib) that is # packaged separately from the Python code that consumes it (jax). # # We occasionally need to make backwards-incompatible changes to jaxlib, in # which case we need to be able to detect when incompatible versions are # installed. def version(): return (0, 1, 7)
このバージョンが
$ pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-0.1.7-cp36-none-linux_x86_64.whlの jaxlib-0.1.7 に対応するのね。