@memoize_thunk def get_xla_client(): return _get_xla_client(FLAGS.jax_xla_backend, FLAGS.jax_platform_name, FLAGS.jax_replica_count)
_get_xla_clientメソッドをFLAGSで設定したパラメータを引数を呼んでいます。
@memoize_thunkは、以下のように定義されていて、関数の戻り値をキャッシュしています。
def memoize_thunk(func): cached = [] return lambda: cached[0] if cached else (cached.append(func()) or cached[0])
_get_xla_clientメソッドは、次のようになっています。
backend_name としては、'xla' と 'xrt'、
platform_nameは、XLAのバックエンド名(CPU、CUDAなど、TPUはどうなるの?)
最後は、replica_countは、レプリカ数。。。
backend_name としては、'xla' と 'xrt'、
platform_nameは、XLAのバックエンド名(CPU、CUDAなど、TPUはどうなるの?)
最後は、replica_countは、レプリカ数。。。
backend_nameに、'xla'を指定した場合のみ、いろいろと設定していますね。
'xrt'の場合は、xla_clientのデフォルト値のままなのね。
'xrt'の場合は、xla_clientのデフォルト値のままなのね。
def _get_xla_client(backend_name, platform_name, replica_count): """Configures and returns a handle to the XLA client. Args: backend_name: backend name, 'xla' or 'xrt' platform_name: platform name for XLA backend replica_count: number of computation replicas with which to configure the backend library. Returns: A client library module, or an object that behaves identically to one. """ global _platform_name xla_client.initialize_replica_count(replica_count) if backend_name == 'xla': if platform_name: xla_client.initialize_platform_name(platform_name) _platform_name = platform_name else: try: xla_client.initialize_platform_name('CUDA') _platform_name = 'CUDA' except RuntimeError: warnings.warn('No GPU found, falling back to CPU.') xla_client.initialize_platform_name('Host') _platform_name = 'Host' return xla_client
xla_client は、install_xla_in_source_tree.shの最後に、
-e '/from tensorflow.compiler.xla.service import hlo_pb2/d' \ < "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xla_client.py)" \ > "${TARGET}/jaxlib/xla_client.py"とありますね。TensorFlowのソースコードのcompiler/xla/python/xla_client.pyを使っていることになります。
昨日のget_compile_optionsメソッドで、CompileOptionsを使っていますね。
class CompileOptions(object): """Python object for XLA compile options. These options can be passed to the 'compile' step when using a local XLA client. """ def __init__(self): self.generate_hlo_graph = None self.dump_optimized_hlo_proto_to = None self.dump_unoptimized_hlo_proto_to = None self.dump_per_pass_hlo_proto_to = None self.hlo_profile = False
お、これって、昨日の get_compile_options メソッドで設定していたものと同じですね。
def get_compile_options(): """Returns the compile options to use, as derived from flag values.""" compile_options = None if FLAGS.jax_dump_hlo_graph is not None: compile_options = get_xla_client().CompileOptions() compile_options.generate_hlo_graph = FLAGS.jax_dump_hlo_graph if FLAGS.jax_hlo_profile: compile_options = compile_options or get_xla_client().CompileOptions() compile_options.hlo_profile = True if FLAGS.jax_dump_hlo_unoptimized: compile_options = compile_options or get_xla_client().CompileOptions() path = _hlo_path(FLAGS.jax_dump_hlo_unoptimized, 'hlo_unoptimized') compile_options.dump_unoptimized_hlo_proto_to = path if FLAGS.jax_dump_hlo_optimized: compile_options = compile_options or get_xla_client().CompileOptions() path = _hlo_path(FLAGS.jax_dump_hlo_optimized, 'hlo_optimized') compile_options.dump_optimized_hlo_proto_to = path if FLAGS.jax_dump_hlo_per_pass: compile_options = compile_options or get_xla_client().CompileOptions() path = _hlo_path(FLAGS.jax_dump_hlo_per_pass, 'hlo_per_pass') compile_options.dump_per_pass_hlo_proto_to = path return compile_options