今週の土曜日は、fpgax #11 + TFUG ハード部:DNN専用ハードについて語る会ね。
あたしも「TensorFlow XLA:XLAとは、から、最近の利用事例について」について、お話しますが、
Google/jax
については、ソースコードの頻繁な更新があるので、今回は対象外としました。
しかしながら、それじゃ、最新の利用事例としては、ちょっとと思いましたので、
今週のブログは「Google/jaxのソースコード解析」とします。
しかしながら、それじゃ、最新の利用事例としては、ちょっとと思いましたので、
今週のブログは「Google/jaxのソースコード解析」とします。
"""Interface and utility functions to XLA. This module wraps the XLA client(s) and builders to standardize their interfaces and provide some automatic type mapping logic for converting between Numpy and XLA. There are also a handful of related casting utilities. """
最初は、オプションの内容をチェック。
FLAGS = flags.FLAGS flags.DEFINE_bool('jax_enable_x64', strtobool(os.getenv('JAX_ENABLE_X64', "False")), 'Enable 64-bit types to be used.') flags.DEFINE_string('jax_dump_hlo_graph', None, 'Regexp of HLO graphs to dump.') flags.DEFINE_bool('jax_hlo_profile', False, 'Enables HLO profiling mode.') flags.DEFINE_string('jax_dump_hlo_unoptimized', None, 'Dirpath for unoptimized HLO dump.') flags.DEFINE_string('jax_dump_hlo_optimized', None, 'Dirpath for optimized HLO dump.') flags.DEFINE_string('jax_dump_hlo_per_pass', None, 'Dirpath for per-pass HLO dump.') flags.DEFINE_integer('jax_replica_count', 1, 'Replica count for computations.') flags.DEFINE_enum( 'jax_xla_backend', 'xla', ['xla', 'xrt'], 'Either "xla" for the XLA service directly, or "xrt" for an XRT backend.') flags.DEFINE_string( 'jax_backend_target', 'local', 'Either "local" or "rpc:address" to connect to a remote service target.') flags.DEFINE_string( 'jax_platform_name', '', 'Platform name for XLA. The default is to attempt to use a ' 'GPU if available, but fall back to CPU otherwise. To set ' 'the platform manually, pass "Host" for CPU or "CUDA" for ' 'GPU.')
jax_dump_hloが XLAのHLO関連。
これらの処理は、get_compile_optionsメソッドで処理しています。
def get_compile_options(): """Returns the compile options to use, as derived from flag values.""" compile_options = None if FLAGS.jax_dump_hlo_graph is not None: compile_options = get_xla_client().CompileOptions() compile_options.generate_hlo_graph = FLAGS.jax_dump_hlo_graph if FLAGS.jax_hlo_profile: compile_options = compile_options or get_xla_client().CompileOptions() compile_options.hlo_profile = True if FLAGS.jax_dump_hlo_unoptimized: compile_options = compile_options or get_xla_client().CompileOptions() path = _hlo_path(FLAGS.jax_dump_hlo_unoptimized, 'hlo_unoptimized') compile_options.dump_unoptimized_hlo_proto_to = path if FLAGS.jax_dump_hlo_optimized: compile_options = compile_options or get_xla_client().CompileOptions() path = _hlo_path(FLAGS.jax_dump_hlo_optimized, 'hlo_optimized') compile_options.dump_optimized_hlo_proto_to = path if FLAGS.jax_dump_hlo_per_pass: compile_options = compile_options or get_xla_client().CompileOptions() path = _hlo_path(FLAGS.jax_dump_hlo_per_pass, 'hlo_per_pass') compile_options.dump_per_pass_hlo_proto_to = path return compile_options
・HLOのグラフをダンプするかどうか? ・HLOのプロファイル ・HLOの最適化か関連でのダンプ ・HLO処理中のpass(最適化)でのダンプ
XLA/HLOでの デバッグ ようですね。