Vengineerの戯言

人生は短いけど、長いです。人生を楽しみましょう!

Google/jax のソースコード解析(その5)


xla_primitive_callableメソッド
def xla_primitive_callable(prim, *abstract_args, **kwargs):
  """Builds and compiles an XLA computation for a single primitive.

  Args:
    prim: the primitive to compile.
    *abstract_args: descriptions of the arguments (presumably abstract
      values); each one is converted with `xla_shape`.
    **kwargs: parameters forwarded to `primitive_computation`.

  Returns:
    A callable that runs the compiled computation through
    `execute_compiled_primitive`, converting the raw output with the result
    handler derived from the computation's return shape.
  """
  # Materialize the shapes: they are consumed twice (unpacked into
  # primitive_computation, then passed to Compile). The builtin Python 3
  # `map` returns a one-shot iterator that would be empty on the second use.
  shapes = list(map(xla_shape, abstract_args))
  built_c = primitive_computation(prim, *shapes, **kwargs)
  result_shape = xla_shape_to_result_shape(built_c.GetReturnValueShape())
  handle_result = result_handler(result_shape)
  compiled = built_c.Compile(shapes, xb.get_compile_options())
  return partial(execute_compiled_primitive, compiled, handle_result)

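xla_primitive_callableメソッドの最後の行では、execute_compiled_primitiveメソッドに compiled と handle_result を部分適用（partial）した関数を返しています。execute_compiled_primitiveメソッドは次のとおりです。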

def execute_compiled_primitive(compiled, result_handler, *args):
  """Runs an already-compiled primitive on concrete Python arguments.

  Each argument is dtype-canonicalized and transferred with `device_put`
  before execution. The raw result of `compiled.Execute` is passed through
  `result_handler`, and that value is returned.
  """
  input_bufs = []
  for arg in args:
    canonical = canonicalize_pyval_dtype(arg)
    input_bufs.append(device_put(canonical))
  # Second Execute argument is the negation of core.skip_checks.
  checks_enabled = not core.skip_checks
  raw_result = compiled.Execute(input_bufs, checks_enabled)
  return result_handler(raw_result)

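compiled.Executeメソッドは、local_computation_builder.ccにて、次のように定義されています（なお、Python側では第2引数として not core.skip_checks を渡していますが、下記のC++のシグネチャは引数が1つだけなので、参照しているソースのバージョンが異なる可能性があります）。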

// Executes the compiled computation as replica 0 and returns its result.
//
// Parameters:
//   argument_handles: input buffers, in the order the executable expects
//     them. Only their underlying ShapedBuffers (via shaped_buffer()) are
//     handed to Run().
// Returns:
//   On success, a newly heap-allocated LocalShapedBuffer wrapping the result
//   buffer (NOTE(review): presumably owned and freed by the caller, e.g. the
//   Python binding -- confirm). On failure of the device-ordinal lookup, the
//   device assignment, or the execution itself, an InternalError carrying the
//   underlying status message. A failure to obtain the LocalClient is returned
//   as-is by TF_ASSIGN_OR_RETURN.
StatusOr<LocalShapedBuffer*> CompiledLocalComputation::Execute(
    absl::Span<LocalShapedBuffer* const> argument_handles) {
  TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient());
  StatusOr<int> device_ordinal_status = client->ReplicaNumberToDeviceOrdinal(0);
  // Every failure below is funnelled into this status so that it is reported
  // through the single InternalError at the end of the function.
  StatusOr<ScopedShapedBuffer> result_buffer_status;
  if (!device_ordinal_status.ok()) {
    result_buffer_status = device_ordinal_status.status();
  } else {
    const int device_ordinal = device_ordinal_status.ValueOrDie();
    VLOG(3) << "Replica 0 mapped to device ordinal for execution: "
            << device_ordinal;

    std::vector<const ShapedBuffer*> argument_buffers;
    argument_buffers.reserve(argument_handles.size());
    for (LocalShapedBuffer* const handle : argument_handles) {
      argument_buffers.push_back(handle->shaped_buffer());
    }

    // One replica, one computation. Check the placement status instead of
    // calling ConsumeValueOrDie() unconditionally, which would abort the
    // process on failure; report it like the other failures instead.
    StatusOr<DeviceAssignment> device_assignment_status =
        client->backend()
            .computation_placer()
            ->AssignDevices(1, /*computation_count=*/1);
    if (!device_assignment_status.ok()) {
      result_buffer_status = device_assignment_status.status();
    } else {
      // Safe: the status was checked just above.
      DeviceAssignment device_assignment =
          device_assignment_status.ConsumeValueOrDie();

      ExecutableRunOptions options;
      options.set_device_ordinal(device_ordinal);
      options.set_allocator(client->backend().memory_allocator());
      options.set_intra_op_thread_pool(
          client->backend().eigen_intra_op_thread_pool_device());
      // `options` receives the address of this local, so it must outlive the
      // Run() call below (it does: both live in this scope).
      options.set_device_assignment(&device_assignment);

      result_buffer_status = executable_->Run(argument_buffers, options);
    }
  }

  if (!result_buffer_status.ok()) {
    return InternalError(
        "Failed running replica 0 (other replicas may have failed as well): "
        "%s.",
        result_buffer_status.status().ToString());
  }
  return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie());
}

executable_->Run にて、実行しています。

xla_primitive_callableメソッドは、apply_primitiveメソッドから呼ばれています。
def apply_primitive(prim, *args, **kwargs):
  """Evaluates `prim` on concrete `args` by compiling and running it.

  The arguments are first abstractified to build the compiled callable. That
  callable is then invoked on the original, concrete arguments, and its
  result is returned.
  """
  abstract_args = tuple(abstractify(arg) for arg in args)
  compiled_fun = xla_primitive_callable(prim, *abstract_args, **kwargs)
  result = compiled_fun(*args)
  return result

apply_primitiveメソッドは、standard_primitiveメソッドの中で prim.def_impl に登録されています。どうやら、standard_primitiveが基本的なプリミティブを定義するためのものらしいですね。
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
  """Creates a Primitive named `name` and registers its standard rules.

  Registers three things:
    * impl: `xla.apply_primitive` bound to the new primitive;
    * abstract eval: `standard_abstract_eval` bound to `shape_rule` and
      `dtype_rule`;
    * XLA translation: `translation_rule`, or, when that is falsy,
      `standard_translate` bound to `name`.

  Returns the new primitive.
  """
  primitive = Primitive(name)

  impl_fn = partial(xla.apply_primitive, primitive)
  primitive.def_impl(impl_fn)

  abstract_eval_fn = partial(standard_abstract_eval, shape_rule, dtype_rule)
  primitive.def_abstract_eval(abstract_eval_fn)

  # Any falsy translation_rule (not only None) falls back to the default.
  if translation_rule:
    rule = translation_rule
  else:
    rule = partial(standard_translate, name)
  xla.translations[primitive] = rule

  return primitive