xla_primitive_callableメソッドの定義は、次のとおりです。
def xla_primitive_callable(prim, *abstract_args, **kwargs):
  """Build and compile the XLA computation for `prim` and return a runner.

  Args:
    prim: the primitive handed to `primitive_computation`.
    *abstract_args: values converted one by one with `xla_shape`.
    **kwargs: forwarded unchanged to `primitive_computation`.

  Returns:
    A `partial` of `execute_compiled_primitive` with the compiled computation
    and the result handler already bound; call it with the concrete arguments.
  """
  # `shapes` is consumed twice (unpacked into primitive_computation, then
  # passed to Compile). Materialize it so the second use still sees every
  # element even when `map` is the lazy Python 3 builtin; if `map` already
  # returns a list this is just a shallow copy.
  shapes = list(map(xla_shape, abstract_args))
  built_c = primitive_computation(prim, *shapes, **kwargs)
  result_shape = xla_shape_to_result_shape(built_c.GetReturnValueShape())
  handle_result = result_handler(result_shape)
  compiled = built_c.Compile(shapes, xb.get_compile_options())
  return partial(execute_compiled_primitive, compiled, handle_result)
最後に、execute_compiled_primitiveメソッドなるものがあります。
def execute_compiled_primitive(compiled, result_handler, *args):
  """Run `compiled` on `args` and pass its output through `result_handler`.

  Every argument goes through `canonicalize_pyval_dtype` and then `device_put`
  before being handed to `compiled.Execute`.
  """
  input_bufs = []
  for arg in args:
    input_bufs.append(device_put(canonicalize_pyval_dtype(arg)))
  # The second argument to Execute is the negation of `core.skip_checks`.
  raw_result = compiled.Execute(input_bufs, not core.skip_checks)
  return result_handler(raw_result)
compiled.Executeメソッドは、local_computation_builder.ccにて、次のように定義されています。
// Runs the compiled executable once, using the device ordinal that replica 0
// maps to.
//
// `argument_handles` supplies the inputs; each handle's shaped_buffer() is
// passed to Run in order. On success the result is wrapped in a
// LocalShapedBuffer allocated with `new` (presumably owned by the caller --
// confirm at the call site). Any failure, either while resolving the device
// ordinal or inside Run, is reported as an InternalError that embeds the
// underlying status text.
StatusOr<LocalShapedBuffer*> CompiledLocalComputation::Execute(
    absl::Span<LocalShapedBuffer* const> argument_handles) {
  TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient());

  StatusOr<int> device_ordinal_status = client->ReplicaNumberToDeviceOrdinal(0);
  StatusOr<ScopedShapedBuffer> result_buffer_status;

  if (device_ordinal_status.ok()) {
    const int device_ordinal = device_ordinal_status.ValueOrDie();
    VLOG(3) << "Replica 0 mapped to device ordinal for execution: "
            << device_ordinal;

    // Collect the underlying shaped buffers in argument order.
    std::vector<const ShapedBuffer*> argument_buffers;
    argument_buffers.reserve(argument_handles.size());
    for (LocalShapedBuffer* handle : argument_handles) {
      argument_buffers.push_back(handle->shaped_buffer());
    }

    // `options` stores a pointer to this object, so it has to stay alive
    // until Run returns.
    DeviceAssignment device_assignment =
        client->backend()
            .computation_placer()
            ->AssignDevices(1, /*computation_count=*/1)
            .ConsumeValueOrDie();

    ExecutableRunOptions options;
    options.set_device_ordinal(device_ordinal);
    options.set_allocator(client->backend().memory_allocator());
    options.set_intra_op_thread_pool(
        client->backend().eigen_intra_op_thread_pool_device());
    options.set_device_assignment(&device_assignment);

    result_buffer_status = executable_->Run(argument_buffers, options);
  } else {
    // No usable device ordinal: carry that status into the shared error path.
    result_buffer_status = device_ordinal_status.status();
  }

  if (result_buffer_status.ok()) {
    return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie());
  }
  return InternalError(
      "Failed running replica 0 (other replicas may have failed as well): "
      "%s.",
      result_buffer_status.status().ToString());
}
executable_->Run にて、実行しています。
xla_primitive_callableメソッドは、apply_primitiveメソッドから呼ばれています。
def apply_primitive(prim, *args, **kwargs):
  """Compile `prim` for the given arguments and run it on them right away.

  Each argument is passed through `abstractify`; the results, together with
  `kwargs`, go to `xla_primitive_callable`, and the callable it returns is
  invoked with the original `args`.
  """
  runner = xla_primitive_callable(prim, *map(abstractify, args), **kwargs)
  return runner(*args)
standard_primitiveメソッド、どうやらこれが基本的なものらしいですね。
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
  """Create a `Primitive` called `name` and register its standard rules.

  Args:
    shape_rule: bound, together with `dtype_rule`, into `standard_abstract_eval`.
    dtype_rule: see `shape_rule`.
    name: name given to the new `Primitive`; also bound into the fallback
      translation rule.
    translation_rule: optional entry for `xla.translations`. When falsy,
      `partial(standard_translate, name)` is registered instead.

  Returns:
    The newly created primitive.
  """
  prim = Primitive(name)

  # Implementation: apply the primitive through `xla.apply_primitive`.
  impl = partial(xla.apply_primitive, prim)
  prim.def_impl(impl)

  abstract_eval = partial(standard_abstract_eval, shape_rule, dtype_rule)
  prim.def_abstract_eval(abstract_eval)

  if translation_rule:
    xla.translations[prim] = translation_rule
  else:
    xla.translations[prim] = partial(standard_translate, name)
  return prim