Vengineerの戯言

人生は短いけど、長いです。人生を楽しみましょう!

Mesh TensorFlow の中身を覗いてみた

@Vengineerの戯言 : Twitter
SystemVerilogの世界へようこそすべては、SystemC v0.9公開から始まった 

Mesh TensorFlow の github の眺めたら、書いてあった。

  • The parameters of the model do not fit on one device
  • An example is so large that the activations do not fit on one device
  • Lower-latency parallel inference (at batch size 1).

github.com

最初のモデルのパラメータが大きい時、

2番目のアクティベーションが1つのデバイスにフィットしないぐらい大きい時、

は分かるんだけど、

3番目の Lower-latency parallel inference もそうなのね。。。

 

お、これって、Groq や Graphcore の Inference のアプローチと同じじゃん。。。

 

バイスの設定って、こんな感じに。4つのデバイス

devices = ["gpu:0", "gpu:1", "gpu:2", "gpu:3"]
mesh_shape = [("all_processors", 4)]
layout_rules = [("batch", "all_processors")]
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                                              mesh_shape, layout_rules, devices)
lowering = mtf.Lowering(graph, {mesh:mesh_impl})
tf_update_ops = [lowering.lowered_operation(update_w1_op),
                            lowering.lowered_operation(update_w2_op)]

 この中の、mtf.Lowering がポイントなんだろうか?

 

Loweringの ここ にあった。

def __init__(self, graph, mesh_to_impl, autostack=True, log_file=None):
    """Creates a Lowering of a Graph.
    Args:
        graph: Graph.
       mesh_to_impl: {Mesh: MeshImpl}. Keys are the Mesh's in the graph and
their values are MeshImpl's, which map Tensor Dimension names to
Mesh Dimension names.
        autostack: a boolean. If True, then the graph gets rewritten to
reduce the number of variables (see rewrite_stack_variables()).
This is a helpful performance optimization for large meshes.
For more fine-grained control, you can call
graph.rewrite_stack_variables() yourself before creating the Lowering.
        log_file: an optional string. If provided, information about the variables
and operations will also be logged to this file.

MeshImplとしては、次の2つがあって、SimdMeshImpl が TPU 用

グラフ を mesh_to_impl によって、Lowering するんだけど、autostack を True (デフォルトでも True) にすると、variables の数を少なくするように、書き換えるようですね。

再度見てみると、 

lowering = mtf.Lowering(graph, {mesh:mesh_impl})
tf_update_ops = [lowering.lowered_operation(update_w1_op),
                            lowering.lowered_operation(update_w2_op)]

 lowering.lowered_operation メソッドが下記のように、self.operations[op]を返しているだけですね。ということは、self.operations はどうやっているの?

def lowered_operation(self, op):
    return self.operations[op]

 __init__ メソッドの中で、

for op in graph.operations:
# tf.logging.info("Lowering operation %s" % op.to_string)
with tf.name_scope(op.name):
    op.lower(self)

 のように、op.lower(self) がポイントかな。。。

Operation クラスの定義では、

def lower(self, lowering):
    raise NotImplementedError("Lower not implemented")

 のように未実装なので、Operation クラスを継承するクラスで定義されているっぽい。

例えば、SlicewiseOperation クラス の lower メソッドは、こんな感じ。

def lower(self, lowering):
    # Check that only splittable dims are split
    mesh_impl = lowering.mesh_impl(self)
    for t in self.inputs + self.outputs:
        layout = mesh_impl.tensor_layout(t)
        for d, mesh_axis in zip(t.shape.dims, layout.tensor_axis_to_mesh_axis):
            if mesh_axis is not None and d.name not in self._splittable_dims:
                raise ValueError("dimension %s is not declared as splittable" % d)
            values = mesh_impl.slicewise(
            self._tf_fn, *[lowering.tensors[x] for x in self.inputs])
        if len(self.outputs) == 1:
            values = values,
        for output, value in zip(self.outputs, values):
            lowering.set_tensor_lowering(output, value)

 この中に出てくる set_tensor_lowering は、Lowering クラスのメソッドの模様。

def set_tensor_lowering(self, tensor, laid_out_tensor):
    self.verify_slice_shapes(tensor, laid_out_tensor)
    self.tensors[tensor] = laid_out_tensor