One place where get_xla_client is used is the get_jax_computation_builder_class function:
```python
@memoize_thunk
def get_jax_computation_builder_class():
  xla_base = get_xla_client().ComputationBuilder
  jax_base = _JaxComputationBuilderBase
  return type('JaxComputationBuilder', (jax_base, xla_base), {})
```
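This function uses type() to define a class named JaxComputationBuilder with jax_base and xla_base as its base classes and an empty body ({}). It then returns that class. The class is created when the function is called, not when the module is imported, so the XLA base class can depend on the result of get_xla_client(). @memoize_thunk caches the result, so the class is built only once.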
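The following is a minimal, self-contained sketch of the same pattern. It relies on hypothetical stand-ins: XlaBuilderStandIn plays the XLA ComputationBuilder, JaxBaseStandIn plays _JaxComputationBuilderBase, and memoize_thunk is a simple cache for a zero-argument function. It is not JAX's actual helper. The sketch shows that type() with two bases produces a class whose method resolution order (MRO) checks the JAX-side base first and falls back to the XLA-side base:

```python
import functools


def memoize_thunk(func):
    # Hypothetical stand-in for JAX's memoize_thunk: run a zero-argument
    # function once and return the cached result on every later call.
    cached = []

    @functools.wraps(func)
    def thunk():
        if not cached:
            cached.append(func())
        return cached[0]
    return thunk


class XlaBuilderStandIn(object):
    # Hypothetical stand-in for get_xla_client().ComputationBuilder.
    def Constant(self, value):
        return ("xla", value)

    def Build(self):
        return "xla-build"


class JaxBaseStandIn(object):
    # Hypothetical stand-in for _JaxComputationBuilderBase.
    def Build(self):
        return "jax-build"


@memoize_thunk
def get_builder_class():
    # The bases are chosen when this is first called, not at import time.
    return type('JaxComputationBuilder', (JaxBaseStandIn, XlaBuilderStandIn), {})


cls = get_builder_class()
print([c.__name__ for c in cls.__mro__])
# ['JaxComputationBuilder', 'JaxBaseStandIn', 'XlaBuilderStandIn', 'object']
b = cls()
print(b.Build())        # jax-build (found on the first base)
print(b.Constant(3))    # ('xla', 3) (falls through to the second base)
print(get_builder_class() is cls)  # True: the class is created only once
```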
The JAX-side base class, _JaxComputationBuilderBase, is defined as follows:
```python
class _JaxComputationBuilderBase(object):
  """Base class implementing all of JaxComputationBuilder.

  This class is intended to override and augment the interface of an XLA
  ComputationBuilder to form JaxComputationBuilder, as made clear by
  `get_jax_computation_builder_class`, which relies on Python's
  method-resolution order to set up inheritance-like behavior. The class
  inheritance setup is deferred because the choice of the XLA
  ComputationBuilder class is based on the result of `get_xla_client()`. That
  is, the choice is based at least on the setting of flags, which are available
  only after module initialization time.
  """
  # The JAXComputationBuilder is implemented using subclassing and inheritance
  # (via this base class), rather than a wrap-and-delegate style, simply to
  # avoid having to spell out all the methods to be forwarded to a wrapped
  # ComputationBuilder, especially since the underlying ComputationBuilders are
  # likely to be revised in the future. An alternative is to generate these
  # forwarding methods programmatically.

  # Method name case follows that of the XLA ComputationBuilder
  # pylint: disable=invalid-name

  def Build(self, *args, **kwargs):
    return super(_JaxComputationBuilderBase, self).Build(
        *args, backend=_get_backend(), **kwargs)

  def Parameter(self, value, name=None, parameter_num=None):
    return super(_JaxComputationBuilderBase, self).ParameterWithShape(
        shape_of(value), name=name, parameter_num=parameter_num)

  def NumpyArrayConstant(self, value):
    normalized_value = normalize_to_xla_dtypes(value)
    return super(_JaxComputationBuilderBase, self).Constant(normalized_value)

  def ConstantLike(self, example_value, value):
    example_value = onp.asarray(example_value)
    return self.Constant(onp.array(value).astype(example_value.dtype))

  def Constant(self, py_val):
    """Translate constant `py_val` to a constant for this ComputationBuilder.

    Args:
      py_val: a Python value to be translated to a constant.

    Returns:
      A representation of the constant, either a ComputationDataHandle or None
    """
    py_type = type(py_val)
    if py_type in _constant_handlers:
      return _constant_handlers[py_type](self, py_val)
    else:
      raise TypeError("No constant handler for type: {}".format(py_type))
```
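Constant dispatches on the exact type of py_val through the _constant_handlers dictionary. The sketch below reproduces that lookup in isolation. The names register_constant_handler and BuilderSketch, and the lambda handlers, are hypothetical and used only for illustration. The key is type(py_val) rather than an isinstance check, so in this sketch a bool value does not match the int handler:

```python
# Hypothetical module-level registry, mirroring _constant_handlers above.
_constant_handlers = {}


def register_constant_handler(type_, handler_fun):
    # Hypothetical helper: map an exact Python type to a handler function.
    _constant_handlers[type_] = handler_fun


class BuilderSketch(object):
    def Constant(self, py_val):
        # Same lookup as Constant above: exact type(), no isinstance().
        py_type = type(py_val)
        if py_type in _constant_handlers:
            return _constant_handlers[py_type](self, py_val)
        raise TypeError("No constant handler for type: {}".format(py_type))


register_constant_handler(int, lambda c, v: ("int-constant", v))
register_constant_handler(float, lambda c, v: ("float-constant", v))

b = BuilderSketch()
print(b.Constant(3))      # ('int-constant', 3)
print(b.Constant(2.5))    # ('float-constant', 2.5)
try:
    b.Constant(True)      # bool subclasses int, but type(True) is bool
except TypeError as e:
    print(e)              # No constant handler for type: <class 'bool'>
```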
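The Build method above adds the backend argument and forwards the call through super() to the XLA ComputationBuilder's Build. That XLA-side Build is what creates the LocalComputation instance: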
```python
def Build(self, root=None, backend=XLA_LOCAL_BACKEND):
  if root is not None:
    return LocalComputation(
        self._client.BuildWithRoot(root), is_compiled=False, backend=backend)
  else:
    return LocalComputation(
        self._client.Build(), is_compiled=False, backend=backend)
```
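The two Build methods work together as a chain. _JaxComputationBuilderBase.Build adds backend=_get_backend() and defers through super() to the XLA ComputationBuilder's Build, which wraps the result in a LocalComputation. The sketch below models that chain. LocalComputationStandIn, XlaBuilderStandIn, and _get_backend are hypothetical stand-ins, and the real _client calls are replaced by placeholder values:

```python
class LocalComputationStandIn(object):
    # Hypothetical stand-in for LocalComputation: just records its arguments.
    def __init__(self, c_computation, is_compiled, backend):
        self.c_computation = c_computation
        self.is_compiled = is_compiled
        self.backend = backend


class XlaBuilderStandIn(object):
    # Same branching as the XLA Build above; placeholders replace self._client.
    def Build(self, root=None, backend="XLA_LOCAL"):
        if root is not None:
            return LocalComputationStandIn(
                ("with-root", root), is_compiled=False, backend=backend)
        return LocalComputationStandIn(
            "built", is_compiled=False, backend=backend)


def _get_backend():
    # Hypothetical: the real function picks a backend from flags; fixed here.
    return "cpu-backend"


class JaxBaseStandIn(object):
    def Build(self, *args, **kwargs):
        # Forward everything to the next class in the MRO, adding the backend.
        return super(JaxBaseStandIn, self).Build(
            *args, backend=_get_backend(), **kwargs)


JaxBuilder = type('JaxComputationBuilder', (JaxBaseStandIn, XlaBuilderStandIn), {})

comp = JaxBuilder().Build()
print(comp.c_computation, comp.is_compiled, comp.backend)  # built False cpu-backend
comp2 = JaxBuilder().Build(root="r")
print(comp2.c_computation)  # ('with-root', 'r')
```

Because the override always supplies backend, a caller who also passes backend= to the JAX-level Build gets a TypeError for a duplicate keyword argument. The backend is effectively fixed by _get_backend().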
This Build method is in turn called by the primitive_computation function:
```python
@memoize
def primitive_computation(prim, *shapes, **kwargs):
  c = xb.make_computation_builder("primitive_computation")
  xla_args = map(c.ParameterWithShape, shapes)
  xla_result = translation_rule(prim)(c, *xla_args, **kwargs)
  try:
    return c.Build()
  except RuntimeError as e:
    # try for a better error message by using the abstract_eval checks
    prim.abstract_eval(*map(aval_from_xla_shape, shapes), **kwargs)
    raise e
```
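The structure of primitive_computation can be imitated without XLA. The steps are: memoize results by their arguments, attempt the build, and on a RuntimeError run the primitive's abstract_eval so that a shape problem surfaces as a more descriptive error. Everything here is a hypothetical stand-in (memoize, FakePrim, fake_build), and the shapes are plain tuples rather than XLA shapes:

```python
import functools


def memoize(fun):
    # Hypothetical stand-in for @memoize: cache on hashable positional args
    # plus sorted keyword items.
    cache = {}

    @functools.wraps(fun)
    def memoized(*args, **kwargs):
        key = (args, tuple(sorted(kwargs.items())))
        if key not in cache:
            cache[key] = fun(*args, **kwargs)
        return cache[key]
    return memoized


class FakePrim(object):
    # Hypothetical primitive whose abstract_eval gives a clearer shape error.
    def __init__(self, name):
        self.name = name

    def abstract_eval(self, *shapes):
        if len(set(shapes)) != 1:
            raise TypeError("{} requires equal shapes, got {}".format(self.name, shapes))
        return shapes[0]


build_calls = []


def fake_build(prim, shapes):
    # Stand-in for c.Build(): fails with an opaque RuntimeError on mismatch.
    build_calls.append((prim.name, shapes))
    if len(set(shapes)) != 1:
        raise RuntimeError("INVALID_ARGUMENT")
    return ("computation", prim.name, shapes)


@memoize
def primitive_computation(prim, *shapes):
    try:
        return fake_build(prim, shapes)
    except RuntimeError:
        # Try for a better error; if abstract_eval passes, re-raise the original.
        prim.abstract_eval(*shapes)
        raise


add = FakePrim("add")
r1 = primitive_computation(add, (2, 3), (2, 3))
r2 = primitive_computation(add, (2, 3), (2, 3))
print(r1 is r2, len(build_calls))  # True 1
try:
    primitive_computation(add, (2, 3), (4,))
except TypeError as e:
    print(e)  # add requires equal shapes, got ((2, 3), (4,))
```

A failing call raises before anything is stored, so this memoize caches only successful builds.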