こういうちょっとしたものも、エコシステムでは非常に大切なのよね。。。
特に、zero-copy conversionが大切。。。
フレームワーク間で頻繁にコピーが発生していちゃね。。。。
最後は、TVM の Packed Function に渡して、実行できるんだよね。
引用 n = tvm.convert(56) X = tvm.placeholder((n,n), name='X') Y = tvm.placeholder((n,n), name='Y') k = tvm.reduce_axis((0, n), name='k') Z = tvm.compute((n,n), lambda i,j : tvm.sum(X[i,k]*Y[k,j], axis=k)) s = tvm.create_schedule(Z.op) fmm = tvm.build(s, [X, Y, Z], target_host='llvm', name='fmm')
のように、TVM で Packed Function (fmm) を生成し、
引用 from tvm.contrib.dlpack import to_pytorch_func # fmm is the previously built TVM function (Python function) # fmm is the wrapped TVM function (Python function) fmm_pytorch = to_pytorch_func(fmm) z2 = torch.empty(56,56) fmm_pytorch(x, y, z2) np.testing.assert_allclose(z.numpy(), z2.numpy())のように、to_pytorch_func にて、TVM の Packed Function を Pytorch の関数に変換するだけだって。。。
MxNet でもできるって、
引用 import mxnet from tvm.contrib.mxnet import to_mxnet_func ctx = mxnet.cpu(0) x = mxnet.nd.uniform(shape=(56,56), ctx=ctx) y = mxnet.nd.uniform(shape=(56,56), ctx=ctx) z = mxnet.nd.empty(shape=(56,56), ctx=ctx) f = tvm.build(s, [X, Y, Z], target_host='llvm', name='f') f_mxnet = to_mxnet_func(f) f_mxnet(x, y, z) np.testing.assert_allclose(z.asnumpy(), x.asnumpy().dot(y.asnumpy()))
to_mxnet_func を使って、変換していますね。