From Google, not XLA this time, but something called JAX...
Here's a quote from the paper:
JAX is built atop the same tracing library used within Autograd, which, being designed for self-closure, recognizes its own operations as primitives. JAX also has Numpy’s numerical functions among its primitives. As a result, it generates code for Python functions written in familiar Numpy and that involve arbitrary-order forward- and reverse-mode automatic differentiation. On the back end, JAX uses XLA for array-level program optimization and code generation. Whereas other systems focus on providing easy access to a fixed set of hand-written, target-specific numerical kernels, JAX provides a means of composition for all of XLA’s supported target architectures: by trace-compiling PSC routines, JAX automatically stages out new kernels from existing ones. The acronym JAX stands for “just after execution”, since to compile a function we first monitor its execution once in Python.
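To get a feel for what the quote describes, here is a minimal sketch using the public `jax` API: plain NumPy-style code, derivatives of any order built by composing `jax.grad` (reverse mode) and `jax.jacfwd` (forward mode), and `jax.jit` tracing the function once and handing it to XLA. The function `f` and the `trace_count` counter are made up for illustration and are not from the paper. The counter is just a way to observe that the Python body only runs while JAX is tracing, which is the "just after execution" idea.

```python
import jax
import jax.numpy as jnp

# An ordinary NumPy-style function (hypothetical example); jax.numpy ops are JAX primitives.
def f(x):
    return jnp.sin(x) * x ** 2

df = jax.grad(f)                    # reverse mode: f'(x)
d2f = jax.grad(jax.grad(f))         # reverse-over-reverse: f''(x)
d2f_fwd = jax.jacfwd(jax.grad(f))   # forward-over-reverse: also f''(x)

# jax.jit traces the Python function once per input shape/dtype and
# hands the recorded program to XLA for compilation.
trace_count = 0

def g(x):
    global trace_count
    trace_count += 1  # Python side effect: only runs while JAX is tracing g
    return jnp.sin(x) * x ** 2

g_jit = jax.jit(g)
y1 = g_jit(jnp.array(1.5))
y2 = g_jit(jnp.array(2.5))  # same shape/dtype: reuses the compiled code, no retrace

x = 1.5
print(df(x), d2f(x), d2f_fwd(x))
print(trace_count)  # 1

# The traced intermediate program (a "jaxpr") that gets lowered to XLA.
print(jax.make_jaxpr(f)(x))
```

Because each transformation takes a function and returns a function, stacking `grad` on `grad` (or `jacfwd` on `grad`) is all it takes to get the "arbitrary-order" derivatives the quote mentions.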
And then...
Apparently it runs on Cloud TPU too, and on top of that, it scales...
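As a rough idea of what "it scales" can look like in code, here is a sketch using `jax.pmap`, which runs the same function on every local device (for example, the cores of a Cloud TPU host) and lets them communicate through collectives like `jax.lax.psum`. This is only one way to do it; newer JAX versions also offer sharding-based parallelism through `jax.jit`. The `step` function and the `"devices"` axis name are hypothetical. On an ordinary CPU machine `n` is just 1, so the code still runs, just without any actual parallelism.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # number of local accelerators (1 on a plain CPU)

# One row per device along the leading axis.
xs = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)

def step(x):
    local = jnp.sum(x ** 2)                              # per-device partial result
    return jax.lax.psum(local, axis_name="devices")      # all-reduce across devices

totals = jax.pmap(step, axis_name="devices")(xs)
print(totals)  # every device ends up holding the same global sum
```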