Vengineerの妄想(準備期間)

人生は短いけど、長いです。人生を楽しみましょう!

Google/jax のソースコード解析(その1)



あたしも「TensorFlow XLA:XLAとは、から、最近の利用事例について」について、お話しますが、

Google/jax


については、ソースコードの頻繁な更新があるので、今回は対象外としました。
しかしながら、それじゃ、最新の利用事例としては、ちょっとと思いましたので、
今週のブログは「Google/jaxのソースコード解析」とします。

Google/jax の XLA 関連は、xla_bridge.pyにありますので、まずはこの「xla_bridge.py」を見ていきます。

"""Interface and utility functions to XLA.
This module wraps the XLA client(s) and builders to standardize their interfaces
and provide some automatic type mapping logic for converting between Numpy and
XLA. There are also a handful of related casting utilities.
"""

最初は、オプションの内容をチェック。

FLAGS = flags.FLAGS
flags.DEFINE_bool('jax_enable_x64',
                  strtobool(os.getenv('JAX_ENABLE_X64', "False")),
                  'Enable 64-bit types to be used.')
flags.DEFINE_string('jax_dump_hlo_graph', None, 'Regexp of HLO graphs to dump.')
flags.DEFINE_bool('jax_hlo_profile', False, 'Enables HLO profiling mode.')
flags.DEFINE_string('jax_dump_hlo_unoptimized', None,
                    'Dirpath for unoptimized HLO dump.')
flags.DEFINE_string('jax_dump_hlo_optimized', None,
                    'Dirpath for optimized HLO dump.')
flags.DEFINE_string('jax_dump_hlo_per_pass', None,
                    'Dirpath for per-pass HLO dump.')
flags.DEFINE_integer('jax_replica_count', 1, 'Replica count for computations.')
flags.DEFINE_enum(
    'jax_xla_backend', 'xla', ['xla', 'xrt'],
    'Either "xla" for the XLA service directly, or "xrt" for an XRT backend.')
flags.DEFINE_string(
    'jax_backend_target', 'local',
    'Either "local" or "rpc:address" to connect to a remote service target.')
flags.DEFINE_string(
    'jax_platform_name', '',
    'Platform name for XLA. The default is to attempt to use a '
    'GPU if available, but fall back to CPU otherwise. To set '
    'the platform manually, pass "Host" for CPU or "CUDA" for '
    'GPU.')

jax_dump_hloが XLAのHLO関連。

これらの処理は、get_compile_optionsメソッドで処理しています。
def get_compile_options():
  """Returns the compile options to use, as derived from flag values."""
  compile_options = None
  if FLAGS.jax_dump_hlo_graph is not None:
    compile_options = get_xla_client().CompileOptions()
    compile_options.generate_hlo_graph = FLAGS.jax_dump_hlo_graph
  if FLAGS.jax_hlo_profile:
    compile_options = compile_options or get_xla_client().CompileOptions()
    compile_options.hlo_profile = True
  if FLAGS.jax_dump_hlo_unoptimized:
    compile_options = compile_options or get_xla_client().CompileOptions()
    path = _hlo_path(FLAGS.jax_dump_hlo_unoptimized, 'hlo_unoptimized')
    compile_options.dump_unoptimized_hlo_proto_to = path
  if FLAGS.jax_dump_hlo_optimized:
    compile_options = compile_options or get_xla_client().CompileOptions()
    path = _hlo_path(FLAGS.jax_dump_hlo_optimized, 'hlo_optimized')
    compile_options.dump_optimized_hlo_proto_to = path
  if FLAGS.jax_dump_hlo_per_pass:
    compile_options = compile_options or get_xla_client().CompileOptions()
    path = _hlo_path(FLAGS.jax_dump_hlo_per_pass, 'hlo_per_pass')
    compile_options.dump_per_pass_hlo_proto_to = path
  return compile_options

 ・HLOのグラフをダンプするかどうか?
 ・HLOのプロファイル
 ・HLOの最適化か関連でのダンプ
 ・HLO処理中のpass(最適化)でのダンプ

XLA/HLOでの デバッグ ようですね。