@Vengineerの戯言 : Twitter
SystemVerilogの世界へようこそ、すべては、SystemC v0.9公開から始まった
これ。
pybind11にて、jax を jit で使えるようにするものっぽい。
tensorflow/compiler/xla/python/xla.ccの中で、
BuildJaxjitSubmodule(m);
を呼び出している。
具体的には、"jit" にて、CompiledFunction を生成して、__call__ で生成した CompiledFunction::Call を呼び出す感じ。
Google、計算部分に、jax を上手く利用する感じですかね。
jax の jit は、「Compilation with jit」ということろに説明がありますね。
import jax.numpy as jnp
from jax import jitdef slow_f(x):
# Element-wise ops see a large benefit from fusion
return x * x + x * 2.0x = jnp.ones*1
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x) # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x) # ~ 14.5 ms / loop (also on GPU via JAX)
Titan X上で、通常のGPU動作に対して、3倍速いと。。。
記録として、
numpy をお使いの皆様、
— Vengineer@アマゾンプライムで映画三昧 (@Vengineer) 2020年8月15日
Google の jax を使えば、numpy 互換でなおかつ、jit を使えば、速くなるようです。
Titan Xにて、jit 有り無しで 3 倍違うと (この例では)https://t.co/1XuowkXu5C
*1:5000, 5000