@Vengineerの戯言 : Twitter
SystemVerilogの世界へようこそ、すべては、SystemC v0.9公開から始まった
Mesh TensorFlow の github の眺めたら、書いてあった。
- The parameters of the model do not fit on one device
- An example is so large that the activations do not fit on one device
- Lower-latency parallel inference (at batch size 1).
最初のモデルのパラメータが大きい時、
2番目のアクティベーションが1つのデバイスにフィットしないぐらい大きい時、
は分かるんだけど、
3番目の Lower-latency parallel inference もそうなのね。。。
お、これって、Groq や Graphcore の Inference のアプローチと同じじゃん。。。
devices = ["gpu:0", "gpu:1", "gpu:2", "gpu:3"]
mesh_shape = [("all_processors", 4)]
layout_rules = [("batch", "all_processors")]
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
mesh_shape, layout_rules, devices)
lowering = mtf.Lowering(graph, {mesh:mesh_impl})
tf_update_ops = [lowering.lowered_operation(update_w1_op),
lowering.lowered_operation(update_w2_op)]
この中の、mtf.Lowering がポイントなんだろうか?
Loweringの ここ にあった。
def __init__(self, graph, mesh_to_impl, autostack=True, log_file=None):
"""Creates a Lowering of a Graph.
Args:
graph: Graph.
mesh_to_impl: {Mesh: MeshImpl}. Keys are the Mesh's in the graph and
their values are MeshImpl's, which map Tensor Dimension names to
Mesh Dimension names.
autostack: a boolean. If True, then the graph gets rewritten to
reduce the number of variables (see rewrite_stack_variables()).
This is a helpful performance optimization for large meshes.
For more fine-grained control, you can call
graph.rewrite_stack_variables() yourself before creating the Lowering.
log_file: an optional string. If provided, information about the variables
and operations will also be logged to this file.
MeshImplとしては、次の2つがあって、SimdMeshImpl が TPU 用
グラフ を mesh_to_impl によって、Lowering するんだけど、autostack を True (デフォルトでも True) にすると、variables の数を少なくするように、書き換えるようですね。
再度見てみると、
lowering = mtf.Lowering(graph, {mesh:mesh_impl})
tf_update_ops = [lowering.lowered_operation(update_w1_op),
lowering.lowered_operation(update_w2_op)]
lowering.lowered_operation メソッドが下記のように、self.operations[op]を返しているだけですね。ということは、self.operations はどうやっているの?
def lowered_operation(self, op):
return self.operations[op]
__init__ メソッドの中で、
for op in graph.operations:
# tf.logging.info("Lowering operation %s" % op.to_string)
with tf.name_scope(op.name):
op.lower(self)
のように、op.lower(self) がポイントかな。。。
def lower(self, lowering):
raise NotImplementedError("Lower not implemented")
のように未実装なので、Operation クラスを継承するクラスで定義されているっぽい。
例えば、SlicewiseOperation クラス の lower メソッドは、こんな感じ。
def lower(self, lowering):
# Check that only splittable dims are split
mesh_impl = lowering.mesh_impl(self)
for t in self.inputs + self.outputs:
layout = mesh_impl.tensor_layout(t)
for d, mesh_axis in zip(t.shape.dims, layout.tensor_axis_to_mesh_axis):
if mesh_axis is not None and d.name not in self._splittable_dims:
raise ValueError("dimension %s is not declared as splittable" % d)
values = mesh_impl.slicewise(
self._tf_fn, *[lowering.tensors[x] for x in self.inputs])
if len(self.outputs) == 1:
values = values,
for output, value in zip(self.outputs, values):
lowering.set_tensor_lowering(output, value)
この中に出てくる set_tensor_lowering は、Lowering クラスのメソッドの模様。
def set_tensor_lowering(self, tensor, laid_out_tensor):
self.verify_slice_shapes(tensor, laid_out_tensor)
self.tensors[tensor] = laid_out_tensor