The primitive_computation function is used inside the xla_primitive_callable function as follows:
```python
def xla_primitive_callable(prim, *abstract_args, **kwargs):
  shapes = map(xla_shape, abstract_args)
  built_c = primitive_computation(prim, *shapes, **kwargs)
  result_shape = xla_shape_to_result_shape(built_c.GetReturnValueShape())
  handle_result = result_handler(result_shape)
  compiled = built_c.Compile(shapes, xb.get_compile_options())
  return partial(execute_compiled_primitive, compiled, handle_result)
```
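The shape of xla_primitive_callable is: build an XLA computation for the primitive, compile it once, then bind the compiled executable and a result handler into a callable with functools.partial. The following is a minimal plain-Python sketch of that pattern, under loud assumptions: FakeComputation, build_computation, make_result_handler, execute_compiled and primitive_callable are hypothetical stand-ins, not JAX or XLA APIs, and "compiling" here just records the shapes and hands back the Python function.

```python
from functools import partial


class FakeComputation:
    """Hypothetical stand-in for a built (not yet compiled) XLA computation."""

    def __init__(self, fn, return_shape):
        self.fn = fn
        self.return_shape = return_shape

    def GetReturnValueShape(self):
        return self.return_shape

    def Compile(self, shapes):
        # Hypothetical: "compiling" records the argument shapes and returns the function.
        self.arg_shapes = list(shapes)
        return self.fn


def build_computation(prim, *shapes):
    # Hypothetical: the result shape is taken to be the first argument's shape.
    return FakeComputation(prim, shapes[0])


def make_result_handler(result_shape):
    # Hypothetical: pair the raw result with its shape.
    return lambda raw: (result_shape, raw)


def execute_compiled(compiled, handle_result, *args):
    return handle_result(compiled(*args))


def primitive_callable(prim, *arg_shapes):
    shapes = list(arg_shapes)                    # a list, so it can be iterated twice below
    built = build_computation(prim, *shapes)     # 1. build the computation
    handle_result = make_result_handler(built.GetReturnValueShape())
    compiled = built.Compile(shapes)             # 2. compile once, up front
    # 3. bind executable + handler; callers then pass only the runtime arguments
    return partial(execute_compiled, compiled, handle_result)


add = primitive_callable(lambda x, y: x + y, (3,), (3,))
print(add(1, 2))  # → ((3,), 3)
```

The point of returning a partial is that all the shape-dependent work (building, result-shape lookup, compilation) happens once, and every later call goes straight to execution.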
The actual compilation happens in built_c.Compile.
The Compile method is defined in XLA's xla_client.py:
```python
def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None):
  """Compiles an un-compiled local computation.

  Local computations are the result of a "LocalComputationBuild'ing" process
  -- they start in uncompiled form, and via a call to Compile() turn into a
  compiled local computation.

  Raises:
    ValueError: if this is already a compiled local computation.

  Arguments:
    argument_shapes: parameter shapes -- they are first laid out by layout_fn
      if layout_fn is provided. Otherwise, the default layout for those shapes
      will be used.
    compile_options: options to use for compilation, includes an optional laid
      out result shape for the computation.
    layout_fn: lambda that is used to lay out the argument/result shapes.

  Returns:
    A newly *compiled* local computation instance.
  """
  if self._is_compiled:
    raise ValueError('Attempt to compile a compiled local XLA computation.')

  result_shape = _wrap_shape(self.computation.GetReturnValueShape())

  if layout_fn:
    argument_shapes = [
        shape.map_leaves(layout_fn) for shape in argument_shapes
    ]
    result_shape = result_shape.map_leaves(layout_fn)

  argument_shapes = list(argument_shapes)

  compile_options = compile_options or CompileOptions()
  compile_options.result_shape = result_shape
  if self._backend.backend_type == BackendType.XRT:
    c = self.computation.CompileForXrt(
        argument_shapes, _maybe_encode_string(self._backend.target))
  else:
    c = self.computation.Compile(argument_shapes, compile_options)
  return LocalComputation(c, is_compiled=True, backend=self._backend)
```
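Before dispatching, Compile applies layout_fn (if given) to every leaf of each argument shape and of the result shape through map_leaves, so tuple shapes are laid out element by element. The sketch below imitates only that step. It is an illustration under assumptions: the Shape class, its map_leaves implementation, the row_major layout function and lay_out are hypothetical, not xla_client's real Shape type.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple


@dataclass
class Shape:
    """Hypothetical shape: an array leaf, or a tuple of sub-shapes (children)."""
    dims: Tuple[int, ...] = ()
    layout: Optional[Tuple[int, ...]] = None  # minor-to-major order; None = unset
    children: List["Shape"] = field(default_factory=list)

    def is_tuple(self) -> bool:
        # Simplification: an empty tuple shape is treated as a leaf here.
        return bool(self.children)

    def map_leaves(self, f: Callable[["Shape"], "Shape"]) -> "Shape":
        # Recurse into tuple shapes; apply f only to array leaves.
        if self.is_tuple():
            return Shape(children=[c.map_leaves(f) for c in self.children])
        return f(self)


def row_major(shape: Shape) -> Shape:
    # Hypothetical layout_fn: minor-to-major = last dimension first.
    return Shape(dims=shape.dims, layout=tuple(reversed(range(len(shape.dims)))))


def lay_out(argument_shapes, result_shape, layout_fn=None):
    # Mirrors Compile: shapes are only rewritten when layout_fn is provided.
    if layout_fn:
        argument_shapes = [s.map_leaves(layout_fn) for s in argument_shapes]
        result_shape = result_shape.map_leaves(layout_fn)
    return list(argument_shapes), result_shape


args, res = lay_out([Shape(dims=(2, 3))],
                    Shape(children=[Shape(dims=(2, 3)), Shape(dims=(4,))]),
                    row_major)
print(args[0].layout)          # → (1, 0)
print(res.children[1].layout)  # → (0,)
```

Without layout_fn the shapes pass through untouched, which matches the docstring's "the default layout for those shapes will be used".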
The path depends on the backend: with the XRT backend it compiles via self.computation.CompileForXrt, and with any other backend it uses self.computation.Compile.
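The compile-once guard and the backend dispatch can be reproduced with stand-in objects to see which branch receives what. This is a minimal sketch under assumptions: FakeNativeComputation, Backend, maybe_encode_string and the BackendType.XLA_LOCAL member are hypothetical (only BackendType.XRT appears in the original), and a plain dict stands in for CompileOptions.

```python
import enum
from dataclasses import dataclass


class BackendType(enum.Enum):
    XLA_LOCAL = 1  # hypothetical name for "not XRT"
    XRT = 2


@dataclass
class Backend:
    backend_type: BackendType
    target: str = ""


class FakeNativeComputation:
    """Hypothetical stand-in for the native computation object behind xla_client."""

    def Compile(self, argument_shapes, compile_options):
        return ("local-executable", tuple(argument_shapes), compile_options)

    def CompileForXrt(self, argument_shapes, target):
        return ("xrt-executable", tuple(argument_shapes), target)


def maybe_encode_string(s):
    # Assumption: a str target is passed down as UTF-8 bytes.
    return s.encode("utf-8") if isinstance(s, str) else s


class LocalComputation:
    def __init__(self, computation, is_compiled, backend):
        self.computation = computation
        self._is_compiled = is_compiled
        self._backend = backend

    def Compile(self, argument_shapes=(), compile_options=None):
        # Compilation is one-shot: a compiled computation cannot be compiled again.
        if self._is_compiled:
            raise ValueError("Attempt to compile a compiled local XLA computation.")
        argument_shapes = list(argument_shapes)
        compile_options = compile_options or {}  # dict instead of CompileOptions()
        if self._backend.backend_type == BackendType.XRT:
            c = self.computation.CompileForXrt(
                argument_shapes, maybe_encode_string(self._backend.target))
        else:
            c = self.computation.Compile(argument_shapes, compile_options)
        return LocalComputation(c, is_compiled=True, backend=self._backend)


xrt = LocalComputation(FakeNativeComputation(), False,
                       Backend(BackendType.XRT, "localhost:1234"))
print(xrt.Compile([(2,)]).computation)
# → ('xrt-executable', ((2,),), b'localhost:1234')
```

Note that the XRT branch receives the backend's target instead of the compile options, so in this structure result_shape set on compile_options only reaches the non-XRT path.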
Oh, at some point a method called CompileForXrt was added.