Vengineerの戯言 (Vengineer's Ramblings)

Life is short, but it is also long. Let's enjoy it!

Reading the Google/jax source code (Part 4)


The primitive_computation function is used by the xla_primitive_callable function as follows:

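def xla_primitive_callable(prim, *abstract_args, **kwargs):
  # Convert each abstract argument into an XLA shape. `map` is assumed here to be
  # JAX's list-returning safe_map, since `shapes` is consumed twice below.
  shapes = map(xla_shape, abstract_args)
  # Build an XLA computation containing just the primitive `prim`.
  built_c = primitive_computation(prim, *shapes, **kwargs)
  # Derive a result handler from the computation's return-value shape.
  result_shape = xla_shape_to_result_shape(built_c.GetReturnValueShape())
  handle_result = result_handler(result_shape)
  # Compile once; the returned callable reuses this compiled object on every call.
  compiled = built_c.Compile(shapes, xb.get_compile_options())
  return partial(execute_compiled_primitive, compiled, handle_result)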
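The shape of xla_primitive_callable is: build the computation once, compile it once, then use functools.partial to bind the compiled object and the result handler into a callable. The minimal sketch below shows only that pattern. FakeCompiled, make_primitive_callable, and the use of `float` as a result handler are all hypothetical stand-ins, not JAX or XLA APIs.

```python
from functools import partial


class FakeCompiled:
    """Hypothetical stand-in for a compiled XLA computation."""

    def __init__(self, fn):
        self.fn = fn
        self.calls = 0  # how many times the compiled object has been executed

    def execute(self, *args):
        self.calls += 1
        return self.fn(*args)


def execute_compiled_primitive(compiled, handle_result, *args):
    # Run the already-compiled computation, then post-process the raw output.
    return handle_result(compiled.execute(*args))


def make_primitive_callable(fn, handle_result):
    # "Compile" exactly once, when the callable is created...
    compiled = FakeCompiled(fn)
    # ...and return a callable that reuses the same compiled object every time.
    return partial(execute_compiled_primitive, compiled, handle_result)


add = make_primitive_callable(lambda x, y: x + y, float)
print(add(1, 2))  # 3.0
print(add(3, 4))  # 7.0
```

The compilation cost is paid when the callable is created, not on each call. partial keeps `compiled` and `handle_result` as bound leading arguments, so every call goes straight to execution.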

The actual compilation happens in built_c.Compile.

The Compile method is defined in XLA's xla_client.py:

  def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None):
    """Compiles an un-compiled local computation.
    Local computations are the result of a "LocalComputationBuild'ing" process
    -- they start in uncompiled form, and via a call to Compile() turn into a
    compiled local computation.
    Raises:
      ValueError: if this is already a compiled local computation.
    Arguments:
      argument_shapes: parameter shapes -- they are first laid out by layout_fn
        if layout_fn is provided. Otherwise, the default layout for those shapes
        will be used.
      compile_options: options to use for compilation, includes an optional
        laid out result shape for the computation.
      layout_fn: lambda that is used to lay out the argument/result shapes.
    Returns:
      A newly *compiled* local computation instance.
    """
    if self._is_compiled:
      raise ValueError('Attempt to compile a compiled local XLA computation.')

    result_shape = _wrap_shape(self.computation.GetReturnValueShape())

    if layout_fn:
      argument_shapes = [
          shape.map_leaves(layout_fn) for shape in argument_shapes
      ]
      result_shape = result_shape.map_leaves(layout_fn)

    argument_shapes = list(argument_shapes)

    compile_options = compile_options or CompileOptions()
    compile_options.result_shape = result_shape
    if self._backend.backend_type == BackendType.XRT:
      c = self.computation.CompileForXrt(
          argument_shapes, _maybe_encode_string(self._backend.target))
    else:
      c = self.computation.Compile(argument_shapes, compile_options)
    return LocalComputation(c, is_compiled=True, backend=self._backend)
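The backend-independent part of Compile does three things. It refuses to recompile, it runs layout_fn over every leaf of the argument and result shapes, and it falls back to default compile options before storing the result shape on them. To see that flow in isolation, here is a toy sketch. ToyShape, ToyCompileOptions, ToyComputation, and row_major_layout are all hypothetical. ToyShape.map_leaves only imitates the idea of the real Shape.map_leaves and is not its implementation.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass(frozen=True)
class ToyShape:
    """Hypothetical shape: an array leaf (dims + layout) or a tuple of shapes."""
    dims: Tuple[int, ...] = ()
    layout: Optional[Tuple[int, ...]] = None
    children: Optional[Tuple["ToyShape", ...]] = None  # None means "array leaf"

    def map_leaves(self, f: Callable[["ToyShape"], "ToyShape"]) -> "ToyShape":
        # Apply f to every array leaf, rebuilding tuple shapes around the results.
        if self.children is not None:
            return ToyShape(children=tuple(c.map_leaves(f) for c in self.children))
        return f(self)


def row_major_layout(shape: ToyShape) -> ToyShape:
    # Example layout_fn: row-major, i.e. minor-to-major order (n-1, ..., 0).
    return ToyShape(dims=shape.dims, layout=tuple(reversed(range(len(shape.dims)))))


@dataclass
class ToyCompileOptions:
    result_shape: Optional[ToyShape] = None


class ToyComputation:
    def __init__(self, result_shape: ToyShape, is_compiled: bool = False):
        self.result_shape = result_shape
        self.is_compiled = is_compiled
        self.argument_shapes: List[ToyShape] = []
        self.options: Optional[ToyCompileOptions] = None

    def compile(self, argument_shapes=(), compile_options=None, layout_fn=None):
        # Same guard as Compile(): an already-compiled object cannot be recompiled.
        if self.is_compiled:
            raise ValueError("Attempt to compile a compiled computation.")
        result_shape = self.result_shape
        if layout_fn:
            # Lay out every leaf of the argument shapes and of the result shape.
            argument_shapes = [s.map_leaves(layout_fn) for s in argument_shapes]
            result_shape = result_shape.map_leaves(layout_fn)
        # Fall back to default options, then record the (laid-out) result shape.
        compile_options = compile_options or ToyCompileOptions()
        compile_options.result_shape = result_shape
        compiled = ToyComputation(result_shape, is_compiled=True)
        compiled.argument_shapes = list(argument_shapes)
        compiled.options = compile_options
        return compiled


comp = ToyComputation(ToyShape(children=(ToyShape(dims=(2, 3)), ToyShape(dims=(4,)))))
exe = comp.compile([ToyShape(dims=(2, 3))], layout_fn=row_major_layout)
print(exe.argument_shapes[0].layout)                # (1, 0)
print(exe.options.result_shape.children[1].layout)  # (0,)
```

As in the quoted code, the compile call returns a new, compiled object instead of mutating the original. Calling compile on that result raises ValueError.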

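The final step depends on the backend. When the backend type is XRT, Compile calls self.computation.CompileForXrt, passing the argument shapes and the encoded backend target. Otherwise it calls self.computation.Compile, passing the argument shapes and compile_options.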
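Here is a minimal sketch of that dispatch. BackendType, ToyBackend, maybe_encode_string, and compile_for_backend are hypothetical mirrors of the names in the quoted code, and the target string is made up. The function returns a tuple describing which path was taken instead of actually compiling anything.

```python
import enum
from dataclasses import dataclass


class BackendType(enum.Enum):
    # Hypothetical mirror of the two cases Compile() distinguishes.
    XLA_LOCAL = 1
    XRT = 2


@dataclass
class ToyBackend:
    backend_type: BackendType
    target: str = ""


def maybe_encode_string(s):
    # Hypothetical stand-in for _maybe_encode_string: str -> bytes, bytes unchanged.
    return s.encode("utf-8") if isinstance(s, str) else s


def compile_for_backend(backend, argument_shapes, compile_options):
    if backend.backend_type == BackendType.XRT:
        # XRT path: only the shapes and the encoded target are passed on.
        return ("CompileForXrt", list(argument_shapes), maybe_encode_string(backend.target))
    # Local path: the shapes plus the full compile options.
    return ("Compile", list(argument_shapes), compile_options)


print(compile_for_backend(ToyBackend(BackendType.XRT, "hypothetical-xrt-target"), ["f32[2]"], None))
# ('CompileForXrt', ['f32[2]'], b'hypothetical-xrt-target')
print(compile_for_backend(ToyBackend(BackendType.XLA_LOCAL), ["f32[2]"], {"opt": 1}))
# ('Compile', ['f32[2]'], {'opt': 1})
```

In the quoted code, the XRT branch never receives compile_options, even though result_shape was set on them just before the branch.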

Oh, at some point something called CompileForXrt got added.