Vengineerの妄想(準備期間)

人生は短いけど、長いです。人生を楽しみましょう!

TensorFlow XLA に jax_jit なるコマンドが追加された

@Vengineerの戯言 : Twitter
SystemVerilogの世界へようこそすべては、SystemC v0.9公開から始まった 

これ。

github.com

pybind11にて、jax を jit で使えるようにするものっぽい。

tensorflow/compiler/xla/python/xla.ccの中で、

  BuildJaxjitSubmodule(m);

を呼び出している。 

 

具体的には、"jit" にて、CompiledFunction を生成して、__call__ で生成した CompiledFunction::Call を呼び出す感じ。

 

Google、計算部分に、jax を上手く利用する感じですかね。

github.com

jax の jit は、「Compilation with jit」ということろに説明がありますね。

import jax.numpy as jnp
from jax import jit

def slow_f(x):
# Element-wise ops see a large benefit from fusion
return x * x + x * 2.0

x = jnp.ones*1
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x) # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x) # ~ 14.5 ms / loop (also on GPU via JAX)

 Titan X上で、通常のGPU動作に対して、3倍速いと。。。

 

記録として、

 

 

*1:5000, 5000