Vengineerの妄想

人生を妄想しています。

PyTorch/MXNet と TVM と繋ぐには?



PyTorch と TVM を繋ぐために、
DLPack: Open In Memory Tensor Structure といいうものを利用するっていうこと。

x86やCUDAをPythonフレームワークだけを使っているのであれば、全く気にならないが、
組み込み系に持っていこうとすると、ぶつかるんだよね。こういうのって。

こういうちょっとしたものも、エコシステムでは非常に大切なのよね。。。

https://tvm.ai/images/pytorch-dlpack/dlpack.png

特に、zero-copy conversionが大切。。。

フレームワーク間で頻繁にコピーが発生していちゃね。。。。

最後は、TVM の Packed Function に渡して、実行できるんだよね。

引用
    n = tvm.convert(56)
    X = tvm.placeholder((n,n), name='X')
    Y = tvm.placeholder((n,n), name='Y')

    k = tvm.reduce_axis((0, n), name='k')
    Z = tvm.compute((n,n), lambda i,j : tvm.sum(X[i,k]*Y[k,j], axis=k))
    s = tvm.create_schedule(Z.op)
    fmm = tvm.build(s, [X, Y, Z], target_host='llvm', name='fmm')

のように、TVM で Packed Function (fmm) を生成し、

引用
    from tvm.contrib.dlpack import to_pytorch_func
    # fmm is the previously built TVM function (Python function)
    # fmm is the wrapped TVM function (Python function)
    fmm_pytorch = to_pytorch_func(fmm)
    z2 = torch.empty(56,56)
    fmm_pytorch(x, y, z2)
    np.testing.assert_allclose(z.numpy(), z2.numpy())
のように、to_pytorch_func にて、TVM の Packed Function を Pytorch の関数に変換するだけだって。。。

MxNet でもできるって、
引用
    import mxnet
    from tvm.contrib.mxnet import to_mxnet_func
    ctx = mxnet.cpu(0)
    x = mxnet.nd.uniform(shape=(56,56), ctx=ctx)
    y = mxnet.nd.uniform(shape=(56,56), ctx=ctx)
    z = mxnet.nd.empty(shape=(56,56), ctx=ctx)
    f = tvm.build(s, [X, Y, Z], target_host='llvm', name='f')
    f_mxnet = to_mxnet_func(f)
    f_mxnet(x, y, z)
    np.testing.assert_allclose(z.asnumpy(), x.asnumpy().dot(y.asnumpy()))

to_mxnet_func を使って、変換していますね。