Vengineerの妄想(準備期間)

人生は短いけど、長いです。人生を楽しみましょう!

Google/jax のソースコード解析(その2)



昨日出てきたget_compile_optionsに出てきた、get_xla_clientは、xla_bridge.py で次のように定義されています

@memoize_thunk
def get_xla_client():
  return _get_xla_client(FLAGS.jax_xla_backend,
                         FLAGS.jax_platform_name,
                         FLAGS.jax_replica_count)

_get_xla_clientメソッドをFLAGSで設定したパラメータを引数を呼んでいます。

@memoize_thunkは、以下のように定義されていて、関数の戻り値をキャッシュしています。
def memoize_thunk(func):
  cached = []
  return lambda: cached[0] if cached else (cached.append(func()) or cached[0]) 

_get_xla_clientメソッドは、次のようになっています。
backend_name としては、'xla' と 'xrt'、
platform_nameは、XLAのバックエンド名(CPU、CUDAなど、TPUはどうなるの?)
最後は、replica_countは、レプリカ数。。。

backend_nameに、'xla'を指定した場合のみ、いろいろと設定していますね。
'xrt'の場合は、xla_clientのデフォルト値のままなのね。

def _get_xla_client(backend_name, platform_name, replica_count):
  """Configures and returns a handle to the XLA client.
  Args:
    backend_name: backend name, 'xla' or 'xrt'
    platform_name: platform name for XLA backend
    replica_count: number of computation replicas with which to configure the
      backend library.
  Returns:
    A client library module, or an object that behaves identically to one.
  """
  global _platform_name
  xla_client.initialize_replica_count(replica_count)
  if backend_name == 'xla':
    if platform_name:
      xla_client.initialize_platform_name(platform_name)
      _platform_name = platform_name
    else:
      try:
        xla_client.initialize_platform_name('CUDA')
        _platform_name = 'CUDA'
      except RuntimeError:
        warnings.warn('No GPU found, falling back to CPU.')
        xla_client.initialize_platform_name('Host')
        _platform_name = 'Host'
  return xla_client

xla_client は、install_xla_in_source_tree.shの最後に、
  -e '/from tensorflow.compiler.xla.service import hlo_pb2/d' \
  < "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xla_client.py)" \
  > "${TARGET}/jaxlib/xla_client.py"
とありますね。TensorFlowのソースコードcompiler/xla/python/xla_client.pyを使っていることになります。

昨日のget_compile_optionsメソッドで、CompileOptionsを使っていますね。

class CompileOptions(object):
  """Python object for XLA compile options.
  These options can be passed to the 'compile' step when using a local XLA
  client.
  """

  def __init__(self):
    self.generate_hlo_graph = None
    self.dump_optimized_hlo_proto_to = None
    self.dump_unoptimized_hlo_proto_to = None
    self.dump_per_pass_hlo_proto_to = None
    self.hlo_profile = False

お、これって、昨日の get_compile_options メソッドで設定していたものと同じですね。
def get_compile_options():
  """Returns the compile options to use, as derived from flag values."""
  compile_options = None
  if FLAGS.jax_dump_hlo_graph is not None:
    compile_options = get_xla_client().CompileOptions()
    compile_options.generate_hlo_graph = FLAGS.jax_dump_hlo_graph
  if FLAGS.jax_hlo_profile:
    compile_options = compile_options or get_xla_client().CompileOptions()
    compile_options.hlo_profile = True
  if FLAGS.jax_dump_hlo_unoptimized:
    compile_options = compile_options or get_xla_client().CompileOptions()
    path = _hlo_path(FLAGS.jax_dump_hlo_unoptimized, 'hlo_unoptimized')
    compile_options.dump_unoptimized_hlo_proto_to = path
  if FLAGS.jax_dump_hlo_optimized:
    compile_options = compile_options or get_xla_client().CompileOptions()
    path = _hlo_path(FLAGS.jax_dump_hlo_optimized, 'hlo_optimized')
    compile_options.dump_optimized_hlo_proto_to = path
  if FLAGS.jax_dump_hlo_per_pass:
    compile_options = compile_options or get_xla_client().CompileOptions()
    path = _hlo_path(FLAGS.jax_dump_hlo_per_pass, 'hlo_per_pass')
    compile_options.dump_per_pass_hlo_proto_to = path
  return compile_options