Vengineer's Delusions (Vengineerの妄想)

Daydreaming my way through life.

Google/jax source code analysis (Part 3)


One place that uses get_xla_client is the get_jax_computation_builder_class function:
@memoize_thunk
def get_jax_computation_builder_class():
  xla_base = get_xla_client().ComputationBuilder
  jax_base = _JaxComputationBuilderBase
  return type('JaxComputationBuilder', (jax_base, xla_base), {})


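It uses the three-argument form of type() to define a class named JaxComputationBuilder whose base classes are jax_base and xla_base and whose body is empty ({}), and returns that class.
Because the function is decorated with @memoize_thunk, the class is created only once, on the first call, which is after get_xla_client() can be evaluated.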
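To see why a class with an empty body is enough, here is a minimal sketch with toy classes. XlaBuilder, JaxBuilderBase, get_builder_class and the memoize_thunk below are hypothetical stand-ins rather than JAX's actual code, assuming memoize_thunk simply caches the result of a zero-argument function. It shows that type(name, bases, {}) puts the JAX base first in the MRO, so super() inside the JAX base reaches the XLA builder even though the two bases do not inherit from each other.

```python
import functools


# Hypothetical stand-in for JAX's memoize_thunk: call a zero-argument
# function once, then return the cached result on every later call.
def memoize_thunk(func):
    cached = []

    @functools.wraps(func)
    def wrapper():
        if not cached:
            cached.append(func())
        return cached[0]
    return wrapper


class XlaBuilder(object):
    """Toy stand-in for xla_client.ComputationBuilder (hypothetical)."""
    def Build(self, backend="default"):
        return "built on " + backend


class JaxBuilderBase(object):
    """Toy stand-in for _JaxComputationBuilderBase (hypothetical)."""
    def Build(self, *args, **kwargs):
        # super() follows the MRO of the *combined* class, so this call
        # reaches XlaBuilder.Build although JaxBuilderBase does not
        # inherit from XlaBuilder directly.
        return super(JaxBuilderBase, self).Build(*args, backend="cpu", **kwargs)


@memoize_thunk
def get_builder_class():
    # Three-argument type(): class name, tuple of bases, class body (empty).
    return type("JaxBuilder", (JaxBuilderBase, XlaBuilder), {})


cls = get_builder_class()
print([c.__name__ for c in cls.__mro__])  # ['JaxBuilder', 'JaxBuilderBase', 'XlaBuilder', 'object']
print(cls().Build())                      # built on cpu
print(get_builder_class() is cls)         # True
```

Listing the JAX base first is what lets its methods override the XLA ones while still delegating to them through super().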

The first base class, jax_base (that is, _JaxComputationBuilderBase), is defined as follows:
class _JaxComputationBuilderBase(object):
  """Base class implementing all of JaxComputationBuilder.
  This class is intended to override and augment the interface of an XLA
  ComputationBuilder to form JaxComputationBuilder, as made clear by
  `get_jax_computation_builder_class`, which relies on Python's
  method-resolution order to set up inheritance-like behavior. The class
  inheritance setup is deferred because the choice of the XLA ComputationBuilder
  class is based on the result of `get_xla_client()`. That is, the choice is
  based at least on the setting of flags, which are available only after module
  initialization time.
  """
  # The JAXComputationBuilder is implemented using subclassing and inheritance
  # (via this base class), rather than a wrap-and-delegate style, simply to
  # avoid having to spell out all the methods to be forwarded to a wrapped
  # ComputationBuilder, especially since the underlying ComputationBuilders are
  # likely to be revised in the future. An alternative is to generate these
  # forwarding methods programmatically.

  # Method name case follows that of the XLA ComputationBuilder
  # pylint: disable=invalid-name

  def Build(self, *args, **kwargs):
    return super(_JaxComputationBuilderBase, self).Build(
        *args, backend=_get_backend(), **kwargs)

  def Parameter(self, value, name=None, parameter_num=None):
    return super(_JaxComputationBuilderBase, self).ParameterWithShape(
        shape_of(value), name=name, parameter_num=parameter_num)

  def NumpyArrayConstant(self, value):
    normalized_value = normalize_to_xla_dtypes(value)
    return super(_JaxComputationBuilderBase, self).Constant(normalized_value)

  def ConstantLike(self, example_value, value):
    example_value = onp.asarray(example_value)
    return self.Constant(onp.array(value).astype(example_value.dtype))

  def Constant(self, py_val):
    """Translate constant `py_val` to a constant for this ComputationBuilder.
    Args:
      py_val: a Python value to be translated to a constant.
    Returns:
      A representation of the constant, either a ComputationDataHandle or None
    """
    py_type = type(py_val)
    if py_type in _constant_handlers:
      return _constant_handlers[py_type](self, py_val)
    else:
      raise TypeError("No constant handler for type: {}".format(py_type))
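Constant dispatches on the exact Python type of the value through the _constant_handlers table, and ConstantLike only casts the value to the example's dtype before calling Constant. The sketch below imitates that pattern with a hypothetical toy registry; register_constant_handler, the handlers, and the tuples they return are assumptions for illustration, not what JAX actually registers.

```python
import numpy as np

# Hypothetical registry playing the role of JAX's _constant_handlers:
# it maps a Python type to a function (builder, value) -> constant.
_constant_handlers = {}


def register_constant_handler(type_, handler):
    _constant_handlers[type_] = handler


class ToyBuilder(object):
    """Minimal builder whose 'constants' are just tagged tuples."""

    def Constant(self, py_val):
        # Exact-type lookup, as in the excerpt: subclasses are NOT matched.
        py_type = type(py_val)
        if py_type in _constant_handlers:
            return _constant_handlers[py_type](self, py_val)
        raise TypeError("No constant handler for type: {}".format(py_type))

    def ConstantLike(self, example_value, value):
        # Cast `value` to the dtype of `example_value`, then reuse Constant.
        example_value = np.asarray(example_value)
        return self.Constant(np.array(value).astype(example_value.dtype))


register_constant_handler(np.ndarray, lambda c, v: ("array", v.dtype.name, v.tolist()))
register_constant_handler(float, lambda c, v: ("scalar", "float64", v))

b = ToyBuilder()
print(b.Constant(1.5))                                  # ('scalar', 'float64', 1.5)
print(b.ConstantLike(np.zeros(2, np.float32), [1, 2]))  # ('array', 'float32', [1.0, 2.0])
```

Since no handler is registered for int in this toy, b.Constant(1) raises the same kind of TypeError as the excerpt.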

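_JaxComputationBuilderBase.Build only adds backend=_get_backend() and delegates via super() to the Build method of the XLA ComputationBuilder, which creates a LocalComputation object: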
  def Build(self, root=None, backend=XLA_LOCAL_BACKEND):
    if root is not None:
      return LocalComputation(
          self._client.BuildWithRoot(root), is_compiled=False, backend=backend)
    else:
      return LocalComputation(
          self._client.Build(), is_compiled=False, backend=backend)
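Putting the two Build methods together: the JAX-side override injects backend, and the XLA-side Build chooses BuildWithRoot or Build depending on whether root is given. In this sketch LocalComputation, FakeClient, XlaComputationBuilder and _get_backend are hypothetical toys that mimic only the control flow of the excerpt, not the real XLA client.

```python
class LocalComputation(object):
    """Toy result object; only records what it was built with."""
    def __init__(self, c_computation, is_compiled, backend):
        self.c_computation = c_computation
        self.is_compiled = is_compiled
        self.backend = backend


class FakeClient(object):
    """Hypothetical stand-in for the native builder behind self._client."""
    def Build(self):
        return "computation(default root)"

    def BuildWithRoot(self, root):
        return "computation(root={})".format(root)


class XlaComputationBuilder(object):
    def __init__(self):
        self._client = FakeClient()

    def Build(self, root=None, backend="local"):
        # Same branching as the excerpt: explicit root vs. default root.
        if root is not None:
            return LocalComputation(
                self._client.BuildWithRoot(root), is_compiled=False, backend=backend)
        return LocalComputation(
            self._client.Build(), is_compiled=False, backend=backend)


def _get_backend():
    return "gpu"  # hypothetical: the real choice depends on JAX's flags


class JaxBase(object):
    def Build(self, *args, **kwargs):
        return super(JaxBase, self).Build(*args, backend=_get_backend(), **kwargs)


JaxBuilder = type("JaxBuilder", (JaxBase, XlaComputationBuilder), {})

comp = JaxBuilder().Build()
print(comp.c_computation, comp.backend, comp.is_compiled)  # computation(default root) gpu False
comp2 = JaxBuilder().Build(root="add.3")
print(comp2.c_computation)                                 # computation(root=add.3)
```

One side effect of injecting backend this way is that a caller who also passes backend=... gets a TypeError (multiple values for the keyword argument), so in this pattern the backend can only come from _get_backend().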

This Build method is in turn called from the primitive_computation function:
@memoize
def primitive_computation(prim, *shapes, **kwargs):
  c = xb.make_computation_builder("primitive_computation")
  xla_args = map(c.ParameterWithShape, shapes)
  xla_result = translation_rule(prim)(c, *xla_args, **kwargs)
  try:
    return c.Build()
  except RuntimeError as e:
    # try for a better error message by using the abstract_eval checks
    prim.abstract_eval(*map(aval_from_xla_shape, shapes), **kwargs)
    raise e
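primitive_computation combines two ideas: memoization (presumably keyed on the hashable primitive, shapes and keyword arguments), and a fallback that re-runs the primitive's abstract_eval rule when Build fails, hoping for a clearer error message before re-raising the original one. The sketch below reproduces that structure with a hypothetical add primitive; memoize, build_add and abstract_eval_add are assumptions, not JAX's implementations.

```python
import functools


def memoize(fun):
    """Cache results keyed on positional args and sorted kwargs (all hashable)."""
    cache = {}

    @functools.wraps(fun)
    def memoized(*args, **kwargs):
        key = (args, tuple(sorted(kwargs.items())))
        if key not in cache:
            # Only successful results are stored; exceptions propagate uncached.
            cache[key] = fun(*args, **kwargs)
        return cache[key]
    return memoized


calls = []  # records how often the (expensive) build step really runs


def abstract_eval_add(*shapes):
    # Hypothetical shape rule: all operands must have the same shape.
    if len(set(shapes)) != 1:
        raise TypeError("add got incompatible shapes: {}".format(shapes))
    return shapes[0]


def build_add(*shapes):
    # Hypothetical backend step that fails with an opaque message.
    calls.append(shapes)
    if len(set(shapes)) != 1:
        raise RuntimeError("INVALID_ARGUMENT: binary op failed")
    return "add{}".format(shapes[0])


@memoize
def primitive_computation(*shapes):
    try:
        return build_add(*shapes)
    except RuntimeError:
        # Re-run the shape rule; it raises a friendlier error if shapes are bad.
        abstract_eval_add(*shapes)
        raise  # otherwise fall back to the original error


print(primitive_computation((2, 3), (2, 3)))  # add(2, 3)
print(primitive_computation((2, 3), (2, 3)))  # served from the cache
print(len(calls))                             # 1
```

Calling primitive_computation((2, 3), (3,)) then raises the TypeError from abstract_eval_add instead of the opaque RuntimeError.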